chore: Initial Commit
This commit is contained in:
commit
d2e3d24d2c
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
Lab3/dataset/
|
||||||
|
Lab3/mae
|
||||||
|
Lab3/masked
|
||||||
|
Lab3/st
|
||||||
|
Lab3/st2
|
||||||
|
*.pth
|
||||||
|
.ipynb_checkpoints
|
2
Lab1_2/A1_template.csv
Normal file
2
Lab1_2/A1_template.csv
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
URN,Q,K,V,ANSWER
|
||||||
|
6644818,"[[-0.19737370312213898, -1.0540887117385864, 0.02383515052497387, 0.46185705065727234], [-1.2415547370910645, 0.8366656303405762, 0.3741966784000397, 0.9099264740943909], [0.3436168134212494, 0.6154376268386841, 1.1926648616790771, 1.6477248668670654]]","[[1.9663442373275757, 0.15551914274692535, -0.8715013861656189, 0.32070425152778625], [-5.85474967956543, 1.7047394514083862, -1.0024793148040771, 1.3307985067367554], [0.06319630891084671, -2.030783176422119, -5.436811447143555, -0.42979586124420166]]","[[-82.127197265625, 0.9534303545951843, -28.78610610961914, -10.762138366699219], [-16.467313766479492, 60.92831802368164, -36.08392333984375, 31.648052215576172], [20.485767364501953, 45.4570198059082, 15.208494186401367, 31.43212890625]]","[[-7.56060266494751, 40.530540466308594, -4.961359024047852, 23.440505981445312], [-16.6014461517334, 60.75471496582031, -36.01152420043945, 31.536420822143555], [-50.659423828125, 29.360170364379883, -31.904930114746094, 9.3984956741333]]"
|
|
2
Lab1_2/A1_template_template.csv
Normal file
2
Lab1_2/A1_template_template.csv
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
URN,Q,K,V,ANSWER
|
||||||
|
Copy your URN here,Copy your Q value here,Copy your K value here,Copy your V value here,Fill in your Answer Here
|
|
50
Lab1_2/Assignment1.csv
Normal file
50
Lab1_2/Assignment1.csv
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
URN,Q,K,V,ANSWER
|
||||||
|
6424515,"[[61.812015533447266, 40.04122543334961, -1.3776825666427612, 5.644913196563721], [-4.9865241050720215, -9.148029327392578, -90.78352355957031, 27.30191993713379], [21.7161808013916, -63.25381851196289, -28.20044708251953, 14.372629165649414]]","[[-26.79865264892578, 42.945430755615234, -14.29906177520752, 10.068202018737793], [-18.409151077270508, 16.403100967407227, 14.94759750366211, 6.1012983322143555], [-3.0599536895751953, -7.43846321105957, -29.793472290039062, -6.154738903045654]]","[[0.6668468117713928, 0.0321093387901783, 0.06967663019895554, -1.0507230758666992], [0.5716139674186707, -0.160260871052742, -0.08285751193761826, -0.5788695812225342], [0.8007872700691223, -0.27879026532173157, 1.2085530757904053, 1.4593729972839355]]",
|
||||||
|
6483559,"[[5.575249671936035, 23.299489974975586, -10.568557739257812, 14.052878379821777], [1.1961194276809692, 7.634078502655029, -3.072789430618286, 19.930212020874023], [10.905213356018066, 5.266031265258789, 3.967888355255127, -19.720064163208008]]","[[7.83995246887207, 3.862151622772217, 3.054084300994873, 1.4503860473632812], [9.037196159362793, 7.029181480407715, -0.9516154527664185, -6.377782821655273], [-3.326775550842285, -1.8302644491195679, -5.7486701011657715, 8.669612884521484]]","[[-32.508087158203125, 50.03919982910156, -20.734249114990234, 9.130086898803711], [25.76167869567871, -12.930599212646484, -10.532285690307617, -23.300447463989258], [12.573369026184082, -76.31478881835938, 0.0935261994600296, -23.858386993408203]]",
|
||||||
|
6488397,"[[2.6853091716766357, 17.725229263305664, 1.788853645324707, 0.6750410795211792], [-32.531497955322266, 18.576560974121094, -34.81031036376953, -10.207343101501465], [-19.915924072265625, -9.083245277404785, -4.631546974182129, 25.35552406311035]]","[[-29.389095306396484, 24.075138092041016, -9.420976638793945, -7.860787391662598], [-13.026573181152344, -20.35844612121582, -4.833100318908691, 16.232547760009766], [-24.538742065429688, 2.2693328857421875, -13.806289672851562, 33.24972915649414]]","[[-3.366933822631836, -1.0369418859481812, -4.604970932006836, 2.527782917022705], [-3.817992925643921, -3.409868001937866, 8.358521461486816, -3.794895648956299], [-1.9248569011688232, 3.9590156078338623, -0.23751775920391083, -5.782866954803467]]",
|
||||||
|
6541000,"[[25.85124397277832, -14.847068786621094, -8.007427215576172, 13.616039276123047], [22.194473266601562, -41.049625396728516, -48.59806442260742, -30.103748321533203], [-0.9766722321510315, 16.258834838867188, 5.409304141998291, -34.920753479003906]]","[[-19.665332794189453, -9.125407218933105, 25.826215744018555, -24.121002197265625], [-0.6768970489501953, 32.871368408203125, 28.883134841918945, 15.97001838684082], [-32.61582946777344, -26.842838287353516, 29.825098037719727, 25.55153465270996]]","[[-22.1315860748291, -19.117671966552734, -86.41238403320312, -41.470970153808594], [-67.3498306274414, -20.870065689086914, 46.31827926635742, -49.43082809448242], [28.707395553588867, 49.85984802246094, -20.697355270385742, -19.470643997192383]]",
|
||||||
|
6564898,"[[17.990629196166992, 18.082250595092773, 6.080262660980225, -1.9313312768936157], [-1.1121807098388672, 14.952383041381836, 5.227292060852051, 11.955268859863281], [-12.968427658081055, 10.1666841506958, 9.647360801696777, 10.25912094116211]]","[[33.59651184082031, -2.7849419116973877, -4.290157318115234, 16.049331665039062], [16.138713836669922, -3.1907596588134766, 1.9617745876312256, -13.963973999023438], [1.2874808311462402, -11.218521118164062, -4.645533561706543, -21.415369033813477]]","[[29.302244186401367, 37.714359283447266, -14.346443176269531, 26.861482620239258], [15.229442596435547, -30.665781021118164, 27.858675003051758, -3.7787418365478516], [19.023386001586914, 33.741153717041016, 17.80762481689453, -14.525671005249023]]",
|
||||||
|
6595203,"[[-4.682967185974121, -0.46032536029815674, 1.9287296533584595, 1.098872423171997], [1.813373327255249, -0.17624400556087494, -7.465083122253418, 4.692303657531738], [-11.090826034545898, 7.349782943725586, 4.164590835571289, -4.623814582824707]]","[[52.20319747924805, 39.19321060180664, 4.55007791519165, 32.2530403137207], [61.92286682128906, -44.482208251953125, -35.478302001953125, -68.6395263671875], [20.798810958862305, -43.60276412963867, 8.565412521362305, 12.54694938659668]]","[[0.6951367855072021, -0.21053913235664368, 1.9876152276992798, 0.10447879880666733], [0.9846767783164978, 0.6022341847419739, -0.6896607279777527, -1.6564579010009766], [-0.7948723435401917, 0.6899239420890808, -1.8456658124923706, 0.6393752098083496]]",
|
||||||
|
6595493,"[[0.1511625498533249, -0.9745533466339111, -2.2466001510620117, 3.180349349975586], [-5.036211967468262, -3.2606935501098633, 0.8353381156921387, -0.9949823617935181], [2.559250593185425, 1.6193833351135254, -1.08794367313385, -0.50386643409729]]","[[0.7303028106689453, -3.6174004077911377, -2.7354886531829834, 2.467529296875], [-2.1577024459838867, -0.8152199387550354, 9.242465019226074, -0.7949981689453125], [-1.6239731311798096, 5.8632731437683105, -7.212184906005859, 1.169894814491272]]","[[-8.279622077941895, 6.367185115814209, -0.10335130989551544, -15.839896202087402], [38.9275016784668, -5.544412612915039, -17.214385986328125, -35.62171173095703], [7.223878383636475, 2.735114812850952, 7.348245143890381, 13.252175331115723]]",
|
||||||
|
6620065,"[[-8.823065757751465, -0.38451871275901794, 11.639081001281738, -1.696522831916809], [10.205957412719727, -6.988275051116943, -16.136812210083008, -3.569871425628662], [-12.256014823913574, -10.282251358032227, -0.021453989669680595, -4.233187198638916]]","[[-7.298791408538818, 4.976022243499756, 12.335108757019043, 28.913663864135742], [14.395263671875, 20.777713775634766, 21.951862335205078, 10.135398864746094], [-12.99707317352295, 32.36512756347656, -23.65703582763672, -22.85142707824707]]","[[-21.05515480041504, -12.511702537536621, 6.4885172843933105, 7.027595520019531], [10.353309631347656, 8.395435333251953, 2.216571807861328, -2.703104257583618], [-1.446360468864441, 3.8767664432525635, 9.25796127319336, 7.1006669998168945]]",
|
||||||
|
6621031,"[[75.76603698730469, 85.38924407958984, 33.53232192993164, 99.98907470703125], [-65.45123291015625, 6.995937824249268, 0.8409870862960815, -67.66314697265625], [20.245336532592773, 56.575958251953125, -21.347116470336914, 72.61317443847656]]","[[-29.93255043029785, -17.499338150024414, -73.63774108886719, 46.85200119018555], [4.411999702453613, -2.149071216583252, -60.68436050415039, -49.36929702758789], [-48.77145004272461, 44.28728103637695, -33.22221374511719, -71.23680877685547]]","[[-23.651613235473633, -21.843677520751953, -11.714576721191406, 8.189347267150879], [-20.203529357910156, -0.08122234791517258, 36.678199768066406, 7.1701836585998535], [-3.1512069702148438, 13.690710067749023, -15.387824058532715, -37.97343063354492]]",
|
||||||
|
6622936,"[[-8.63397216796875, 4.2975921630859375, -4.522152423858643, -2.323275327682495], [8.693713188171387, -3.304124355316162, 6.605628490447998, -1.9376112222671509], [-2.9686343669891357, -1.4563707113265991, 4.499622821807861, 1.661992073059082]]","[[-2.6573855876922607, 17.252361297607422, 13.382349967956543, -69.54588317871094], [33.054359436035156, -5.027382850646973, 30.04429054260254, -102.58087158203125], [-15.303414344787598, -38.95169448852539, -22.762155532836914, 7.357682228088379]]","[[-35.71727752685547, -27.900691986083984, 37.486934661865234, 10.394813537597656], [21.43458366394043, -68.33717346191406, -45.12615966796875, -24.56119728088379], [-16.092315673828125, 69.08224487304688, -8.384400367736816, 16.746034622192383]]",
|
||||||
|
6623139,"[[10.012069702148438, 14.790369033813477, 10.731794357299805, -20.580612182617188], [8.332863807678223, 1.4901442527770996, -28.56690216064453, -4.9121785163879395], [31.787015914916992, -15.51191520690918, 29.45456314086914, 15.340608596801758]]","[[17.096817016601562, 3.979304552078247, -6.065049171447754, -7.021650314331055], [22.048603057861328, 16.307952880859375, -8.492615699768066, 3.741187810897827], [20.787437438964844, -20.73921012878418, 8.975601196289062, 1.641739010810852]]","[[24.196489334106445, 10.123150825500488, -16.860660552978516, 2.4109020233154297], [11.026750564575195, -0.042206697165966034, 3.502345561981201, 34.46019744873047], [44.30955123901367, -31.93711280822754, -6.689527988433838, -17.10219955444336]]",
|
||||||
|
6627063,"[[-1.0590970516204834, 0.7565467357635498, 0.20126891136169434, -0.6364921927452087], [0.004584392067044973, -2.307931900024414, 1.4653981924057007, 0.8097707629203796], [0.4444141685962677, -1.1070520877838135, -0.4160226285457611, 1.122387170791626]]","[[-13.375662803649902, 8.191596031188965, -32.305240631103516, 44.86949157714844], [15.18038272857666, -61.08460998535156, -41.45825958251953, -32.916629791259766], [48.51971435546875, 32.51784133911133, 1.7294328212738037, -62.1268310546875]]","[[-36.20771408081055, 26.2475643157959, 8.842000961303711, 43.27354431152344], [-62.30213165283203, 41.48006820678711, -6.880943775177002, -21.451093673706055], [-5.928592205047607, -0.6987577080726624, -12.249312400817871, 48.32103729248047]]",
|
||||||
|
6634908,"[[-50.7746696472168, -85.51974487304688, -10.108343124389648, 17.88229751586914], [-57.28345489501953, 28.485719680786133, -44.194435119628906, -51.714439392089844], [-53.68115997314453, 9.014735221862793, 45.90130615234375, 30.313220977783203]]","[[-13.381082534790039, -5.4600396156311035, -23.924701690673828, 49.98859786987305], [-29.438074111938477, 13.526389122009277, 5.885373115539551, -22.310937881469727], [-32.35851287841797, 15.571061134338379, -17.888227462768555, -8.200702667236328]]","[[-0.16054195165634155, 7.1971588134765625, -17.90960121154785, 27.956819534301758], [-10.453465461730957, 9.157516479492188, 8.646942138671875, -25.173913955688477], [38.51534652709961, -1.6076290607452393, 5.644974708557129, 20.121034622192383]]",
|
||||||
|
6635583,"[[-12.722966194152832, 71.30017852783203, 17.90323257446289, -12.047688484191895], [3.3535823822021484, 78.04386901855469, -50.834327697753906, 54.36956787109375], [35.454402923583984, -20.053314208984375, -25.02804183959961, -47.87135314941406]]","[[-43.83270263671875, 32.041934967041016, 11.113154411315918, -20.175634384155273], [32.75693130493164, 17.827775955200195, 4.969196319580078, -20.265155792236328], [-43.46480178833008, 12.0179443359375, 40.43904113769531, -60.19890594482422]]","[[4.519973278045654, -6.386752605438232, -3.6178507804870605, 6.750082015991211], [2.758470058441162, 17.978330612182617, -3.6265249252319336, -11.85763168334961], [7.1331305503845215, -1.1875079870224, 3.5858330726623535, -6.4079179763793945]]",
|
||||||
|
6638234,"[[-4.133094787597656, 5.524759769439697, -0.41305333375930786, -3.8603403568267822], [2.0264086723327637, 3.7351791858673096, 4.967355251312256, 5.839091777801514], [2.6239752769470215, -6.777492523193359, -5.668360233306885, -0.9872243404388428]]","[[129.13450622558594, 40.467010498046875, -4.255222797393799, -26.798913955688477], [-1.1929067373275757, 49.3392333984375, -22.62386131286621, 27.016782760620117], [-69.80500793457031, 31.683269500732422, -89.05852508544922, -3.53053879737854]]","[[1.06186842918396, 14.646303176879883, 5.54833984375, -10.574101448059082], [-0.4040781557559967, 1.0404623746871948, 9.219944953918457, -8.739606857299805], [-6.794262409210205, 1.0300147533416748, -11.154375076293945, 5.335386276245117]]",
|
||||||
|
6640106,"[[3.3120908737182617, 11.714212417602539, -3.5947165489196777, -5.550578594207764], [9.422024726867676, 6.105061054229736, -2.6055192947387695, -4.635824680328369], [-8.967622756958008, 1.748089075088501, 4.3646955490112305, 2.2075018882751465]]","[[49.388126373291016, -3.505145788192749, 6.966372013092041, 20.567304611206055], [9.32462215423584, 30.148683547973633, -1.962703824043274, -9.71722412109375], [-96.27658081054688, 24.82595443725586, 108.67559051513672, -20.08876609802246]]","[[-25.174936294555664, 9.584662437438965, -25.932655334472656, -26.213214874267578], [17.62385368347168, -25.324230194091797, -35.68978500366211, -26.302963256835938], [12.274901390075684, 16.87058448791504, -14.158609390258789, 8.233379364013672]]",
|
||||||
|
6644818,"[[-0.19737370312213898, -1.0540887117385864, 0.02383515052497387, 0.46185705065727234], [-1.2415547370910645, 0.8366656303405762, 0.3741966784000397, 0.9099264740943909], [0.3436168134212494, 0.6154376268386841, 1.1926648616790771, 1.6477248668670654]]","[[1.9663442373275757, 0.15551914274692535, -0.8715013861656189, 0.32070425152778625], [-5.85474967956543, 1.7047394514083862, -1.0024793148040771, 1.3307985067367554], [0.06319630891084671, -2.030783176422119, -5.436811447143555, -0.42979586124420166]]","[[-82.127197265625, 0.9534303545951843, -28.78610610961914, -10.762138366699219], [-16.467313766479492, 60.92831802368164, -36.08392333984375, 31.648052215576172], [20.485767364501953, 45.4570198059082, 15.208494186401367, 31.43212890625]]",
|
||||||
|
6647000,"[[17.123088836669922, 0.7197161912918091, 67.95402526855469, 30.830045700073242], [1.918267011642456, 3.1925175189971924, -60.516944885253906, 33.09083557128906], [-17.23439598083496, 4.878037452697754, -27.22907829284668, 44.515987396240234]]","[[13.926458358764648, 41.11029815673828, 3.837980031967163, 29.635908126831055], [-45.734840393066406, 52.18793487548828, 5.066276550292969, 11.72782039642334], [-97.8349838256836, 28.44172477722168, -43.70535659790039, -25.272418975830078]]","[[-4.117476940155029, 2.6530921459198, -2.3165719509124756, 2.692505359649658], [-4.646360397338867, -6.495508670806885, -2.042623281478882, -4.2362542152404785], [-6.218053817749023, -5.21392822265625, 4.337059020996094, 5.960870742797852]]",
|
||||||
|
6650398,"[[15.136759757995605, -43.95924377441406, -113.8025894165039, 75.7243423461914], [-53.820682525634766, 6.7568206787109375, 11.68793773651123, -59.304115295410156], [17.12495231628418, -80.94425964355469, -24.56743621826172, 72.69660949707031]]","[[-44.538997650146484, 17.452163696289062, -22.793365478515625, -19.52366828918457], [22.004854202270508, -30.501188278198242, 17.9410343170166, -11.477399826049805], [13.915644645690918, -3.8742470741271973, -20.8011531829834, 10.137035369873047]]","[[94.21515655517578, 38.48592758178711, 8.827954292297363, -11.255606651306152], [9.103065490722656, -26.855743408203125, -58.49977111816406, 56.034507751464844], [36.73604202270508, 72.35386657714844, -5.083021640777588, -91.17439270019531]]",
|
||||||
|
6654031,"[[-0.3339273929595947, -0.5318685173988342, 2.0381877422332764, 0.33716848492622375], [-0.5744379758834839, -0.005252655595541, 1.7914447784423828, -0.27126064896583557], [-0.5965532064437866, -1.8395336866378784, 0.9394988417625427, 0.33245497941970825]]","[[14.26134204864502, 6.246121883392334, 16.684396743774414, 13.413249015808105], [18.52135467529297, -4.069742202758789, -8.969866752624512, -4.116239547729492], [-15.628847122192383, 1.7585363388061523, -7.5017409324646, 14.045808792114258]]","[[22.479833602905273, 10.961353302001953, -37.169498443603516, -3.8920676708221436], [22.053361892700195, 6.5353474617004395, 16.050573348999023, 1.3947471380233765], [53.140769958496094, 9.212106704711914, -23.101652145385742, 16.18545913696289]]",
|
||||||
|
6657209,"[[35.31223678588867, -22.971651077270508, 39.2910270690918, 53.64673614501953], [4.348316192626953, -16.771831512451172, -43.288639068603516, 28.353843688964844], [-39.88884353637695, -15.335161209106445, 31.237241744995117, 79.95108795166016]]","[[65.86861419677734, -66.7222671508789, -65.81453704833984, -53.20375442504883], [-21.004255294799805, -36.96867370605469, -49.629032135009766, 4.206972122192383], [-3.789904832839966, 53.388763427734375, -9.336297035217285, 59.61064147949219]]","[[-23.81308937072754, 17.60015106201172, -3.6998114585876465, -44.74623107910156], [39.35554122924805, 21.279781341552734, 25.772464752197266, -23.060672760009766], [23.75358772277832, 24.503376007080078, -16.24953842163086, -52.87109375]]",
|
||||||
|
6664919,"[[-4.587403774261475, 21.389644622802734, -0.8058542013168335, -15.177589416503906], [-30.81719207763672, 16.282472610473633, -30.25719451904297, -4.179492473602295], [-25.774181365966797, 2.7025911808013916, -15.140970230102539, 35.96717071533203]]","[[7.414963722229004, -73.7457275390625, -23.34357261657715, -59.4050178527832], [-37.04792022705078, 39.79411697387695, -28.534637451171875, -9.650754928588867], [-51.72665786743164, 5.047584533691406, 49.07307815551758, -10.990396499633789]]","[[-68.04022979736328, -18.380733489990234, -48.01260757446289, 4.971027851104736], [69.10401153564453, -24.525249481201172, -13.450369834899902, -31.265018463134766], [-17.114112854003906, 88.62091827392578, -49.525413513183594, -44.12730026245117]]",
|
||||||
|
6665234,"[[-58.47925567626953, -6.0381927490234375, 59.27907943725586, 3.1732094287872314], [18.26332664489746, -7.849457263946533, -13.465660095214844, 24.281553268432617], [50.79582214355469, -76.18656158447266, 37.01697540283203, 11.83233642578125]]","[[-19.288162231445312, -10.204944610595703, -34.52219009399414, 6.608401775360107], [6.105591297149658, 17.177522659301758, 41.61949920654297, 59.6090202331543], [-24.104351043701172, -3.944885015487671, 21.40576934814453, 4.275631427764893]]","[[9.385692596435547, 16.75665855407715, -22.38917350769043, 16.042783737182617], [25.004865646362305, -13.535341262817383, -1.943082571029663, -46.629024505615234], [-30.355073928833008, 7.768237590789795, -38.768218994140625, 10.225730895996094]]",
|
||||||
|
6667584,"[[-19.96772003173828, 78.34546661376953, 75.89759826660156, -2.6900768280029297], [-29.359573364257812, 36.52899932861328, -12.29699993133545, -49.58512496948242], [31.139745712280273, -32.53242111206055, -29.126243591308594, 47.14161682128906]]","[[7.564728736877441, -26.98961067199707, 14.280712127685547, -15.75666332244873], [16.308378219604492, -22.065216064453125, -4.871706008911133, -17.236865997314453], [34.73072814941406, 17.954309463500977, -9.708355903625488, -5.783907413482666]]","[[-37.74998474121094, 4.323561191558838, -15.50539779663086, 11.531072616577148], [42.675811767578125, -22.62323760986328, 16.255107879638672, 39.860836029052734], [-47.075279235839844, -16.6245174407959, 25.671066284179688, 0.7380026578903198]]",
|
||||||
|
6673385,"[[-36.83106231689453, 7.701159954071045, 68.32613372802734, 27.919288635253906], [-4.696883678436279, -3.020927906036377, -58.74855422973633, 9.028343200683594], [-35.0689811706543, 1.2082970142364502, 34.020774841308594, 1.3042429685592651]]","[[-16.659854888916016, 4.926011562347412, -10.322761535644531, 2.21010422706604], [-2.2169840335845947, -29.646329879760742, 3.1019740104675293, -26.244598388671875], [9.877544403076172, 14.357820510864258, -9.583166122436523, -2.8210270404815674]]","[[-23.14494514465332, -20.618330001831055, 21.521303176879883, 22.030481338500977], [-8.146241188049316, 40.73670196533203, -14.541313171386719, 36.61677169799805], [-4.874123573303223, 19.764659881591797, 22.12094497680664, 19.16327667236328]]",
|
||||||
|
6674521,"[[-25.607240676879883, 0.8884523510932922, 0.6508675217628479, -55.56379699707031], [-13.399002075195312, -16.82448387145996, 38.20143508911133, 24.327342987060547], [15.179723739624023, 0.8729835152626038, -21.22319793701172, -12.520669937133789]]","[[-1.71498441696167, 5.201033592224121, -11.60346508026123, -5.809443950653076], [5.660852909088135, 3.109971284866333, 8.395805358886719, -13.36353874206543], [8.876840591430664, -4.089223861694336, -5.3249406814575195, -0.06105200946331024]]","[[1.4979137182235718, -1.590577244758606, 0.5862289071083069, -0.25769785046577454], [-2.033968687057495, -2.840846300125122, -0.10808251053094864, -0.877410352230072], [-1.72908616065979, -2.7925174236297607, -0.9345141053199768, -1.441590666770935]]",
|
||||||
|
6676367,"[[69.78858947753906, 53.50902557373047, 63.16095733642578, -34.84336471557617], [28.3581485748291, 8.022404670715332, -29.504745483398438, 47.268455505371094], [19.14493751525879, 24.85756492614746, -53.912784576416016, -74.6889877319336]]","[[18.038612365722656, -22.782411575317383, -23.893470764160156, -1.615665078163147], [-15.572731018066406, 6.45712947845459, -22.30083465576172, 43.831302642822266], [14.26174259185791, -11.669622421264648, -0.7779999375343323, 23.53053855895996]]","[[-11.959007263183594, 4.956019401550293, -7.027940273284912, 0.33905893564224243], [-2.8083086013793945, 3.9247255325317383, -6.848259449005127, -15.556193351745605], [-59.83305358886719, -8.178778648376465, 12.3436861038208, -59.3747673034668]]",
|
||||||
|
6679119,"[[-101.50489044189453, 36.98146057128906, 1.1366711854934692, 63.51799774169922], [-46.073787689208984, 46.47813034057617, -104.55518341064453, 22.0133113861084], [53.618961334228516, 13.772979736328125, -20.572906494140625, 70.72864532470703]]","[[-17.259140014648438, -40.22257614135742, -83.89826965332031, 51.26460266113281], [-75.39704895019531, -31.925426483154297, 59.456993103027344, 45.64452362060547], [69.73789978027344, -66.56124877929688, -11.179588317871094, 32.963687896728516]]","[[-3.6455090045928955, -0.9025506377220154, -0.38173770904541016, -0.06332780420780182], [0.3781833350658417, 0.2325819581747055, -1.8777588605880737, -2.328604221343994], [-3.388141632080078, 3.184062957763672, 1.2819944620132446, -1.5070877075195312]]",
|
||||||
|
6679413,"[[-4.243306636810303, -7.72474479675293, 9.24826717376709, 0.6170316934585571], [6.635672092437744, -5.348725318908691, -9.360831260681152, -0.630814790725708], [7.321711540222168, 9.013751983642578, -3.012909412384033, -1.1782095432281494]]","[[37.10626983642578, -16.38568687438965, 28.856245040893555, 18.558990478515625], [34.18963623046875, -33.233551025390625, -18.14596939086914, -1.2347465753555298], [60.65745162963867, -18.7542781829834, -27.731416702270508, -7.704121112823486]]","[[-0.621583878993988, 5.9794158935546875, 6.679257392883301, -6.4304022789001465], [-14.296626091003418, -1.3633551597595215, 14.830915451049805, 5.049946308135986], [-7.077603340148926, 1.2459709644317627, -8.561247825622559, 26.62421417236328]]",
|
||||||
|
6684315,"[[-10.12170124053955, -3.4088053703308105, -4.571578502655029, -18.48915672302246], [11.414365768432617, -14.282251358032227, -3.3087334632873535, 5.553377151489258], [3.9793548583984375, 10.295501708984375, -13.825559616088867, 17.952943801879883]]","[[-28.288414001464844, -2.9670181274414062, -47.25498580932617, -24.528217315673828], [-29.686786651611328, -26.871307373046875, 9.226241111755371, -18.19982147216797], [-56.61760330200195, 3.2128732204437256, -17.26993179321289, 2.0199966430664062]]","[[5.996661186218262, 12.5982084274292, 2.2733395099639893, -3.376871347427368], [-1.1309608221054077, -7.995626926422119, 1.557140827178955, 2.26686692237854], [-0.9957780838012695, 8.70053482055664, -0.016650710254907608, -3.3884832859039307]]",
|
||||||
|
6684666,"[[50.06999969482422, -49.15105438232422, 3.551252841949463, -0.5807281732559204], [-40.77178955078125, -50.042633056640625, 5.525672435760498, -14.99709415435791], [33.131587982177734, -57.343509674072266, 57.94015884399414, 44.270660400390625]]","[[18.365150451660156, 20.553979873657227, 42.92595672607422, -18.730113983154297], [-50.91423034667969, 21.813535690307617, 3.634154796600342, 4.894844055175781], [-34.97544479370117, -49.256649017333984, -22.60733413696289, -22.322555541992188]]","[[-9.029924392700195, 15.004953384399414, 17.38901138305664, 7.252900123596191], [-18.01449203491211, 5.875588417053223, 5.970248222351074, -1.5177265405654907], [2.648449420928955, -4.20261812210083, -9.511507987976074, -12.106976509094238]]",
|
||||||
|
6685415,"[[35.113365173339844, -77.67644500732422, 50.20078659057617, 8.420661926269531], [-0.2904244661331177, 29.286212921142578, 3.7101621627807617, -40.164581298828125], [17.763580322265625, 18.826738357543945, -35.414276123046875, 1.2325835227966309]]","[[-18.88475227355957, -16.808897018432617, -34.2154541015625, -41.5155029296875], [1.1879208087921143, -34.65090560913086, -52.162071228027344, 35.442989349365234], [15.725982666015625, -15.781286239624023, -68.25137329101562, -37.75694274902344]]","[[-2.339825391769409, 0.8828161954879761, -0.016874248161911964, -4.491491794586182], [-1.3867239952087402, 1.5863149166107178, -1.9972578287124634, 2.9583418369293213], [-1.0680830478668213, -0.24776069819927216, -4.734472274780273, -1.0308974981307983]]",
|
||||||
|
6685730,"[[-25.302352905273438, -30.782779693603516, -27.899147033691406, -24.196697235107422], [-15.96635913848877, -8.835006713867188, -24.468276977539062, -8.283778190612793], [-44.42436981201172, 35.36659240722656, 13.494463920593262, 35.456520080566406]]","[[-1.8692127466201782, -64.31143951416016, 86.1197738647461, 62.821712493896484], [-69.96611022949219, -21.99472427368164, -10.378890991210938, 32.68306350708008], [55.907291412353516, -90.03367614746094, 2.1420021057128906, -67.6994857788086]]","[[-12.555275917053223, 18.81964874267578, -5.327393531799316, 49.367645263671875], [29.352062225341797, -30.749753952026367, -21.662687301635742, -8.036517143249512], [-53.613895416259766, 6.692556381225586, -0.9449663162231445, 24.52071762084961]]",
|
||||||
|
6687280,"[[-3.5868778228759766, -2.3832733631134033, 0.589530348777771, 11.288500785827637], [-10.782378196716309, -2.957477569580078, 7.451799392700195, 4.292666435241699], [7.849056720733643, -14.427438735961914, -11.46081829071045, 5.859859466552734]]","[[-48.13723373413086, -135.08383178710938, -40.40618133544922, 35.46659851074219], [-6.033477306365967, 40.350830078125, -42.84685516357422, -52.211002349853516], [-111.6717529296875, 9.398500442504883, -43.886268615722656, -82.76837921142578]]","[[33.583717346191406, 5.88064432144165, 6.754786014556885, -32.132755279541016], [-17.764244079589844, 82.55520629882812, -30.020442962646484, 24.64034652709961], [-89.43638610839844, -5.825024604797363, 56.58000564575195, -11.378416061401367]]",
|
||||||
|
6687805,"[[-41.159664154052734, -19.464439392089844, 25.74068260192871, 1.872757911682129], [-48.20577621459961, -22.682870864868164, -12.677288055419922, 10.750849723815918], [26.64324188232422, -59.10613250732422, 21.565441131591797, -40.196861267089844]]","[[18.406984329223633, -31.138717651367188, 25.356599807739258, 12.595141410827637], [-33.44847869873047, -21.521642684936523, 25.265287399291992, -23.24741554260254], [-29.790891647338867, 16.76827621459961, -24.88990020751953, 31.059619903564453]]","[[22.057342529296875, -39.534358978271484, 61.31902313232422, -44.508331298828125], [2.0680832862854004, 19.670103073120117, -17.426483154296875, -8.790121078491211], [10.202092170715332, -1.5388749837875366, 8.696026802062988, -5.0123610496521]]",
|
||||||
|
6687869,"[[0.07490905374288559, -1.273012638092041, 0.6518950462341309, -0.6229028701782227], [-0.36243799328804016, -2.1689934730529785, -1.3617095947265625, 2.652318000793457], [1.0193486213684082, -3.5647265911102295, 5.747159004211426, -1.6326216459274292]]","[[-7.607456207275391, 28.280895233154297, -18.318817138671875, 24.297714233398438], [11.687575340270996, -2.194053888320923, -3.6463708877563477, -9.412092208862305], [20.877317428588867, -12.823661804199219, 25.226224899291992, -1.2035841941833496]]","[[2.239046096801758, 2.992509365081787, -3.4312944412231445, 2.049949884414673], [-6.315691947937012, 0.6009021997451782, -1.2477636337280273, -1.7523036003112793], [1.2447832822799683, -5.492430210113525, 3.384784698486328, 1.0218923091888428]]",
|
||||||
|
6689012,"[[10.561138153076172, 24.293989181518555, 8.856136322021484, -8.001155853271484], [-19.200115203857422, -4.852842330932617, 14.093475341796875, 39.930023193359375], [-22.35072898864746, 33.056114196777344, -6.0201215744018555, -39.17753601074219]]","[[-16.157075881958008, -1.0930131673812866, 5.450170040130615, -40.11588668823242], [5.857770919799805, 23.14008331298828, -15.793992042541504, -5.903223037719727], [-13.125252723693848, 11.460956573486328, 117.74522399902344, -92.62993621826172]]","[[-6.055703639984131, -12.719776153564453, 8.75866985321045, 0.42572861909866333], [2.1340320110321045, -13.77920150756836, 14.589826583862305, 2.4945802688598633], [14.152591705322266, 13.32497501373291, -2.875643014907837, -6.7089128494262695]]",
|
||||||
|
6690072,"[[-1.419852614402771, 0.2286374419927597, 0.3345853388309479, 0.2729721665382385], [0.6338568925857544, -0.8546611666679382, 0.869610607624054, -2.08027982711792], [-1.1930214166641235, -0.2104170322418213, -0.9776290655136108, -0.7793132066726685]]","[[33.953224182128906, -25.176177978515625, 11.610986709594727, 27.586698532104492], [25.29509925842285, 8.522324562072754, 4.060436248779297, 37.154415130615234], [-37.96092987060547, 53.34375, 49.52286148071289, 62.254817962646484]]","[[104.4286117553711, 68.14720916748047, 24.368501663208008, 44.12657928466797], [-13.939213752746582, 7.622133255004883, 4.232577323913574, 22.11945343017578], [-9.448180198669434, 40.441768646240234, -8.889688491821289, -3.6762876510620117]]",
|
||||||
|
6691144,"[[22.178943634033203, 12.019502639770508, 31.294391632080078, 46.64274978637695], [44.24049758911133, 19.075437545776367, -3.5804359912872314, -37.055137634277344], [-79.22559356689453, 63.357879638671875, 19.648544311523438, 70.82246398925781]]","[[-0.05032595247030258, 1.217042088508606, 0.5083667635917664, 0.5189406871795654], [-0.6396045684814453, -0.5928763151168823, -0.7169412970542908, -0.1308005154132843], [1.3062458038330078, 1.1942483186721802, 1.5429742336273193, -1.3320108652114868]]","[[-0.9213723540306091, 2.8213400840759277, -1.3256995677947998, -1.1574915647506714], [0.6758412718772888, -0.9888588786125183, -0.7084240317344666, 0.0021383522544056177], [1.3449786901474, -0.31319916248321533, -0.27399250864982605, -0.09528028964996338]]",
|
||||||
|
6691398,"[[3.93902850151062, -18.370189666748047, 20.255130767822266, 12.914362907409668], [-9.456490516662598, 40.97371292114258, 40.47419738769531, -5.51539945602417], [27.83440589904785, -25.091535568237305, 27.184425354003906, -44.957481384277344]]","[[47.26085662841797, -0.5326224565505981, -32.40896987915039, 33.84670639038086], [28.53173065185547, -2.5092074871063232, 35.7825813293457, -13.36374282836914], [-78.28602600097656, 70.24180603027344, -28.044546127319336, 30.846582412719727]]","[[23.940338134765625, -29.73199462890625, -36.24672317504883, 38.23209762573242], [-9.304244995117188, 8.964479446411133, -10.025215148925781, -5.452454566955566], [34.69584274291992, -28.6850643157959, 51.733543395996094, 5.806751251220703]]",
|
||||||
|
6694392,"[[-32.651275634765625, 6.973701477050781, -19.875314712524414, 3.2574033737182617], [52.17797088623047, 17.92009735107422, -2.8579277992248535, 18.69631576538086], [4.248002529144287, -1.815674901008606, -18.37828826904297, -46.7224235534668]]","[[-12.207379341125488, -31.712554931640625, 28.55495262145996, 15.864582061767578], [44.41409683227539, -52.78776550292969, -31.630704879760742, -8.993188858032227], [9.403763771057129, -15.210376739501953, -2.9357264041900635, 7.387798309326172]]","[[12.651803016662598, -1.8179872035980225, -8.39204216003418, -10.712897300720215], [28.703163146972656, 17.471147537231445, -16.652137756347656, -19.11032485961914], [-13.872947692871094, 4.021103858947754, 2.7464599609375, -9.595636367797852]]",
|
||||||
|
6695534,"[[-13.405084609985352, 9.418004035949707, 19.678552627563477, -3.701709032058716], [22.82848358154297, 24.53850746154785, -11.5900297164917, 1.3250523805618286], [-19.844371795654297, -3.264390707015991, 1.127580165863037, -31.00078010559082]]","[[-4.07348108291626, 39.030860900878906, 51.24266052246094, 6.860040187835693], [-30.623088836669922, 36.461055755615234, 0.7204872965812683, -15.802491188049316], [-57.89153289794922, 44.76567459106445, -42.03072738647461, 9.589776039123535]]","[[-3.89924693107605, -54.185546875, -29.025619506835938, -23.89274024963379], [11.800339698791504, 6.643435001373291, -28.981651306152344, -24.919710159301758], [-2.8200690746307373, 24.84758186340332, -3.097881317138672, -16.181795120239258]]",
|
||||||
|
6697080,"[[-1.2122610807418823, -58.39776611328125, -1.0357400178909302, -2.3798153400421143], [-0.0755167007446289, 103.1572265625, 64.51415252685547, -16.170059204101562], [29.636972427368164, 34.911067962646484, -45.17095184326172, 32.64284896850586]]","[[8.11724853515625, 40.43547821044922, -5.311676025390625, -17.611391067504883], [-54.05378341674805, 2.7808048725128174, 43.850887298583984, -14.37118148803711], [3.0494017601013184, -36.97222900390625, 17.106531143188477, 0.831373929977417]]","[[-13.555829048156738, -6.29565954208374, 10.418721199035645, 56.299983978271484], [35.56388854980469, -16.308576583862305, 11.416728019714355, 8.350481033325195], [-58.433876037597656, -4.182812690734863, 37.23177719116211, 10.30872631072998]]",
|
||||||
|
6698610,"[[-6.041436195373535, 9.156213760375977, 1.798213005065918, -11.413951873779297], [-9.282434463500977, -6.500437259674072, 11.692222595214844, -0.07979172468185425], [-10.110106468200684, -1.6897872686386108, -6.469929218292236, 7.471987724304199]]","[[-36.39547348022461, 82.41608428955078, 50.620750427246094, -42.36143112182617], [22.46283721923828, 28.61631965637207, 2.2778117656707764, -10.755844116210938], [-42.1403694152832, 73.60430145263672, 8.922992706298828, 23.86935806274414]]","[[0.27653270959854126, 4.387450695037842, 25.69333839416504, -21.38128662109375], [41.85139083862305, -5.535375118255615, -2.226491928100586, -15.576900482177734], [-4.816816806793213, 5.0897955894470215, -9.330107688903809, 15.888705253601074]]",
|
||||||
|
6699778,"[[-11.416746139526367, -6.792178630828857, -7.349616050720215, -17.06553077697754], [-22.967361450195312, -5.880982875823975, 23.729190826416016, 40.73758316040039], [10.738740921020508, 17.928733825683594, -5.958589553833008, -20.662246704101562]]","[[10.528449058532715, 5.952178955078125, -15.57581901550293, 3.7095937728881836], [-45.46097946166992, 18.745311737060547, -7.266114711761475, 17.222610473632812], [-6.2384467124938965, -5.815711975097656, 0.8986830711364746, -16.490060806274414]]","[[23.08553123474121, 24.72433853149414, -4.271855354309082, -7.6843366622924805], [3.9064691066741943, -2.9235575199127197, -2.4568283557891846, -11.444953918457031], [-2.5821399688720703, -16.44976806640625, 6.171452522277832, 11.762450218200684]]",
|
||||||
|
6699788,"[[12.881512641906738, -26.03125, -31.379695892333984, -0.18674679100513458], [-17.41887855529785, -15.735973358154297, -58.854496002197266, -8.337396621704102], [-51.53794860839844, -14.746176719665527, -56.51786422729492, 29.540775299072266]]","[[12.720685958862305, 5.674424171447754, 17.607324600219727, -5.377486228942871], [-61.59711456298828, -0.8513314127922058, 22.417675018310547, -11.627851486206055], [-5.549165725708008, -6.901993274688721, -21.57571029663086, 7.405230522155762]]","[[-46.049503326416016, 43.273468017578125, 1.1638909578323364, 16.160280227661133], [62.02693557739258, 58.560646057128906, -87.04951477050781, -86.58930969238281], [-28.130151748657227, -34.52029037475586, 14.721302032470703, 26.94708824157715]]",
|
||||||
|
6702876,"[[52.45893859863281, -100.3917007446289, -33.253238677978516, -29.724864959716797], [21.84295654296875, -13.550008773803711, -68.54496765136719, 49.517982482910156], [62.37989807128906, 3.997654914855957, 44.85454559326172, -36.87840270996094]]","[[32.890541076660156, -15.513359069824219, 20.226282119750977, 20.90232276916504], [-16.48346710205078, 4.854818820953369, 11.418033599853516, -0.8921002745628357], [-25.857093811035156, 6.862802982330322, -7.632472515106201, -3.579554319381714]]","[[-60.80701446533203, -5.687557220458984, 15.815590858459473, 14.843527793884277], [-45.50221252441406, 20.321155548095703, -51.37971115112305, 15.229931831359863], [107.47465515136719, -24.3437442779541, -16.632457733154297, -16.89650535583496]]",
|
||||||
|
6702928,"[[18.301334381103516, -78.94622802734375, -6.391412734985352, 17.14784812927246], [43.93109893798828, 22.418628692626953, -55.411903381347656, -62.78773498535156], [-10.213981628417969, -45.5717658996582, -47.030120849609375, -38.92612838745117]]","[[18.429698944091797, -13.791690826416016, -31.753252029418945, 7.937379360198975], [-9.21422290802002, -16.363967895507812, 3.4244227409362793, 11.589090347290039], [0.4459109306335449, -42.66963577270508, -26.912336349487305, -45.24558639526367]]","[[3.989264965057373, -10.394329071044922, 2.9203333854675293, -15.809683799743652], [0.13381671905517578, 6.432409763336182, -13.255852699279785, 5.919764518737793], [0.5510143637657166, -4.444141387939453, -3.3461360931396484, -8.943599700927734]]",
|
||||||
|
6705228,"[[-28.573619842529297, 40.02886199951172, -15.857441902160645, 40.771156311035156], [-15.363801002502441, -22.323780059814453, 28.274812698364258, 22.44999885559082], [5.446697235107422, 10.143810272216797, 37.238800048828125, 15.129122734069824]]","[[18.833009719848633, 45.32041931152344, -24.788484573364258, -7.510664463043213], [40.43198013305664, 4.926146984100342, -3.4637563228607178, 0.8702676892280579], [28.37637710571289, 17.996984481811523, -9.186692237854004, -20.234752655029297]]","[[6.580479145050049, 1.9564368724822998, -7.401310920715332, -8.521329879760742], [-3.6674721240997314, -5.470401763916016, 11.234831809997559, -4.046061038970947], [4.957516193389893, -7.476752758026123, 2.370039224624634, -5.951333999633789]]",
|
|
1146
Lab1_2/Lab1&2_Transformers-base.ipynb
Normal file
1146
Lab1_2/Lab1&2_Transformers-base.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
93
Lab1_2/Lab1&2_Transformers.ipynb
Normal file
93
Lab1_2/Lab1&2_Transformers.ipynb
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "Cv-9Vzunb_tf"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# Import Necessary Library"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {
|
||||||
|
"id": "4f-K54nHb-Uq"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import torch\n",
|
||||||
|
"import torch.nn as nn\n",
|
||||||
|
"import torch.nn.functional as F\n",
|
||||||
|
"import torch.optim as optim\n",
|
||||||
|
"import torch.utils.data as data\n",
|
||||||
|
"import math\n",
|
||||||
|
"import os\n",
|
||||||
|
"import urllib.request\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"from functools import partial\n",
|
||||||
|
"from urllib.error import HTTPError\n",
|
||||||
|
"from datetime import datetime"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"metadata": {
|
||||||
|
"id": "XCv8_IzSdut4"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def scaled_dot_product(q, k, v, mask=None):\n",
|
||||||
|
" # implemented by the student, you can ignore the mask implementation currently\n",
|
||||||
|
" # just assignment all the mask is on\n",
|
||||||
|
"\n",
|
||||||
|
" shape_len = len(k.shape)\n",
|
||||||
|
"\n",
|
||||||
|
" transpose = k.mT\n",
|
||||||
|
" d = k.shape[-1]\n",
|
||||||
|
"\n",
|
||||||
|
" score_scale = torch.matmul(q, transpose)/math.sqrt(d)\n",
|
||||||
|
"\n",
|
||||||
|
" attention_weight = torch.nn.functional.softmax(score_scale, 1)\n",
|
||||||
|
"\n",
|
||||||
|
" output = torch.matmul(attention_weight, v)\n",
|
||||||
|
"\n",
|
||||||
|
" return output, attention_weight"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"provenance": [],
|
||||||
|
"toc_visible": true
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4
|
||||||
|
}
|
7058
Lab3/Week3_Autoencoder+MAE - Copy.ipynb
Normal file
7058
Lab3/Week3_Autoencoder+MAE - Copy.ipynb
Normal file
File diff suppressed because one or more lines are too long
562
Lab3/Week3_Autoencoder+MAE - Copy.py
Normal file
562
Lab3/Week3_Autoencoder+MAE - Copy.py
Normal file
@ -0,0 +1,562 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf-8
|
||||||
|
|
||||||
|
# # Introduction & Import Necessary Setup
|
||||||
|
# In this labsheet, we'll delve into the fascinating world of autoencoders (AEs), a type of neural network renowned for its ability to compress and reconstruct data. Autoencoders work by first encoding input data, such as images, into a compact feature vector through an encoder network. This process effectively distills the essence of the data into a smaller, more manageable form. The feature vector, often referred to as the "bottleneck," plays a crucial role in this compression process, allowing us to represent the input data with fewer features.
|
||||||
|
#
|
||||||
|
# Following compression, a second neural network, known as the decoder, takes over to reconstruct the original data from the compressed feature vector. This remarkable ability to compress and then reconstruct data makes autoencoders extremely valuable in various applications, including data compression and image comparison at a more meaningful level than mere pixel-by-pixel analysis.
|
||||||
|
#
|
||||||
|
# Moreover, our exploration will not stop at the autoencoder framework itself. We will also introduce the concept of "deconvolution" (also known as transposed convolution), a powerful operator used to enlarge feature maps in both height and width dimensions. Deconvolution networks are indispensable in scenarios where we begin with a compact feature vector and aim to generate a full-sized image. This technique is pivotal in various advanced neural network applications, such as Variational Autoencoders (VAEs), Generative Adversarial Networks (GANs), and super-resolution.
|
||||||
|
#
|
||||||
|
# To kick things off, we'll start by importing our standard libraries, setting the stage for our deep dive into the inner workings and applications of autoencoders.
|
||||||
|
|
||||||
|
# In[1]:
|
||||||
|
|
||||||
|
|
||||||
|
## Standard libraries
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
## Imports for plotting
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
get_ipython().run_line_magic('matplotlib', 'inline')
|
||||||
|
from IPython.display import set_matplotlib_formats
|
||||||
|
set_matplotlib_formats('svg', 'pdf') # For export
|
||||||
|
from matplotlib.colors import to_rgb
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.rcParams['lines.linewidth'] = 2.0
|
||||||
|
## Progress bar
|
||||||
|
from tqdm.notebook import tqdm
|
||||||
|
|
||||||
|
## PyTorch
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.utils.data as data
|
||||||
|
import torch.optim as optim
|
||||||
|
# Torchvision
|
||||||
|
import torchvision
|
||||||
|
from torchvision.datasets import CIFAR10
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
DATASET_PATH = "dataset"
|
||||||
|
|
||||||
|
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
|
print("Device:", device)
|
||||||
|
|
||||||
|
|
||||||
|
# # Download and setup the dataset
|
||||||
|
# In this labsheet, our focus shifts to the CIFAR10 dataset, a collection known for its rich, colored images. Each image within CIFAR10 is equipped with 3 color channels and boasts a resolution of 32x32 pixels. This characteristic is particularly advantageous when working with autoencoders, as they are not bound by the constraints of probabilistic image modeling.
|
||||||
|
#
|
||||||
|
# Should you already have the CIFAR10 dataset downloaded in a different directory, it's important to adjust the DATASET_PATH variable accordingly. This step ensures you avoid unnecessary additional downloads, streamlining your workflow and allowing you to dive into the practical exercises more swiftly.
|
||||||
|
|
||||||
|
# In[105]:
|
||||||
|
|
||||||
|
|
||||||
|
# Transformations applied on each image => only make them a tensor
|
||||||
|
transform = transforms.Compose([transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.5,),(0.5,))])
|
||||||
|
|
||||||
|
# Loading the training dataset. We need to split it into a training and validation part
|
||||||
|
train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=transform, download=True)
|
||||||
|
train_set, val_set = torch.utils.data.random_split(train_dataset, [45000, 5000])
|
||||||
|
|
||||||
|
# Loading the test set
|
||||||
|
test_set = CIFAR10(root=DATASET_PATH, train=False, transform=transform, download=True)
|
||||||
|
|
||||||
|
# We define a set of data loaders that we can use for various purposes later.
|
||||||
|
train_loader = data.DataLoader(train_set, batch_size=256, shuffle=True, drop_last=True, pin_memory=True, num_workers=4)
|
||||||
|
val_loader = data.DataLoader(val_set, batch_size=256, shuffle=False, drop_last=False, num_workers=4)
|
||||||
|
test_loader = data.DataLoader(test_set, batch_size=256, shuffle=False, drop_last=False, num_workers=4)
|
||||||
|
|
||||||
|
def get_train_images(num):
|
||||||
|
return torch.stack([train_dataset[i][0] for i in range(num)], dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
# # Building the autoencoder
|
||||||
|
#
|
||||||
|
# In general, an autoencoder consists of an **encoder** that maps the input $x$ to a lower-dimensional feature vector $z$, and a **decoder** that reconstructs the input $\hat{x}$ from $z$. We train the model by comparing $x$ to $\hat{x}$ and optimizing the parameters to increase the similarity between $x$ and $\hat{x}$. See below for a small illustration of the autoencoder framework.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# ![img](https://raw.githubusercontent.com/hqsiswiliam/COM3025_Torch/main/autoencoder.png)
|
||||||
|
|
||||||
|
#
|
||||||
|
# For an educational purpose revision in markdown format, the text could be enhanced as follows:
|
||||||
|
#
|
||||||
|
# To kick off our exploration, we initiate with the construction of the encoder. This component is fundamentally a deep convolutional network tailored for progressively diminishing the image's dimensions. This diminution is achieved through the use of strided convolutions, which methodically reduce the image's size layer by layer. Following the thrice-executed downscaling process, we transition the architecture from convolutional layers to a flattened feature representation. This is achieved by flattening the spatial features into a single vector, which is then processed through several linear layers. As a result, we obtain the latent representation, denoted as
|
||||||
|
# $z$, encapsulating the compressed essence of the input image. The size of this latent vector, $d$, is adjustable, providing flexibility in the encoding capacity of our network.
|
||||||
|
|
||||||
|
# In[59]:
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
num_input_channels : int,
|
||||||
|
base_channel_size : int,
|
||||||
|
latent_dim : int,
|
||||||
|
act_fn : object = nn.GELU):
|
||||||
|
"""
|
||||||
|
Inputs:
|
||||||
|
- num_input_channels : Number of input channels of the image. For CIFAR, this parameter is 3
|
||||||
|
- base_channel_size : Number of channels we use in the first convolutional layers. Deeper layers might use a duplicate of it.
|
||||||
|
- latent_dim : Dimensionality of latent representation z
|
||||||
|
- act_fn : Activation function used throughout the encoder network
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
c_hid = base_channel_size
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Conv2d(num_input_channels, c_hid, kernel_size=3, padding=1, stride=2), # 32x32 => 16x16
|
||||||
|
act_fn(),
|
||||||
|
nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
|
||||||
|
act_fn(),
|
||||||
|
nn.Conv2d(c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 16x16 => 8x8
|
||||||
|
act_fn(),
|
||||||
|
nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
|
||||||
|
act_fn(),
|
||||||
|
nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 8x8 => 4x4
|
||||||
|
act_fn(),
|
||||||
|
nn.Flatten(), # Image grid to single feature vector
|
||||||
|
nn.Linear(2*16*c_hid, latent_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
# self.flatten = nn.Sequential(
|
||||||
|
# nn.Flatten(), # Image grid to single feature vector
|
||||||
|
# nn.Linear(2*16*c_hid, latent_dim)
|
||||||
|
# )
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# x = self.net(x)
|
||||||
|
|
||||||
|
# print(x.shape)
|
||||||
|
|
||||||
|
# return self.flatten(x)
|
||||||
|
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
# # Task1
|
||||||
|
# Now Complete the decoder implementation
|
||||||
|
|
||||||
|
# In[133]:
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
num_input_channels : int,
|
||||||
|
base_channel_size : int,
|
||||||
|
latent_dim : int,
|
||||||
|
act_fn : object = nn.GELU):
|
||||||
|
"""
|
||||||
|
Inputs:
|
||||||
|
- num_input_channels : Number of channels of the image to reconstruct. For CIFAR, this parameter is 3
|
||||||
|
- base_channel_size : Number of channels we use in the last convolutional layers. Early layers might use a duplicate of it.
|
||||||
|
- latent_dim : Dimensionality of latent representation z
|
||||||
|
- act_fn : Activation function used throughout the decoder network
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
c_hid = base_channel_size
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Linear(latent_dim, 2*16*c_hid),
|
||||||
|
act_fn(),
|
||||||
|
nn.Unflatten(1, (2*c_hid, 4, 4)),
|
||||||
|
nn.ConvTranspose2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2, output_padding=1), # 8x8 <= 4x4
|
||||||
|
act_fn(),
|
||||||
|
nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
|
||||||
|
act_fn(),
|
||||||
|
nn.ConvTranspose2d(2*c_hid, c_hid, kernel_size=3, padding=1, stride=2, output_padding=1), # 16x16 <= 8x8
|
||||||
|
act_fn(),
|
||||||
|
nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
|
||||||
|
act_fn(),
|
||||||
|
nn.ConvTranspose2d(c_hid, num_input_channels, kernel_size=3, padding=1, stride=2, output_padding=1), # 32x32 <= 16x16
|
||||||
|
nn.Tanh(),
|
||||||
|
# nn.Sigmoid(),
|
||||||
|
)
|
||||||
|
# You code goes here.
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
# You code goes here.
|
||||||
|
|
||||||
|
|
||||||
|
# # Combining Encoder and Decoder
|
||||||
|
# ## Loss Function: Mean Squared Error (MSE)
|
||||||
|
#
|
||||||
|
# For our loss function, we opt for the Mean Squared Error (MSE). MSE is particularly effective in emphasizing the significance of accurately predicting pixel values that are substantially misestimated by the network. For instance, a minor deviation, such as predicting 127 instead of 128, is deemed less critical. However, larger discrepancies, like confusing a pixel value of 0 with 128, are considered more severe and detrimental to the reconstruction quality.
|
||||||
|
#
|
||||||
|
# Unlike Variational Autoencoders (VAEs) that predict the probability for each pixel value, we employ MSE as a straightforward distance measure. This approach significantly reduces the number of parameters, streamlining the training process. To enhance our understanding of the per-pixel performance, we calculate the summed squared error, averaged across the batch dimension. It's important to note that alternative aggregations (mean or sum) yield equivalent outcomes in terms of resulting parameters.
|
||||||
|
#
|
||||||
|
# ### Limitations of MSE
|
||||||
|
#
|
||||||
|
# Despite its advantages, MSE is not without drawbacks. Primarily, it tends to produce blurrier images, as it inherently removes small noise and high-frequency patterns, which contribute minimally to the overall error. To mitigate this and achieve more realistic reconstructions, integrating Generative Adversarial Networks (GANs) with autoencoders has proven effective. This hybrid approach is explored in various studies ([example 1](https://arxiv.org/abs/1704.02304), [example 2](https://arxiv.org/abs/1511.05644), and [slides](http://elarosca.net/slides/iccv_autoencoder_gans.pdf)).
|
||||||
|
#
|
||||||
|
# Furthermore, MSE may not always accurately reflect visual similarity between images. A case in point is when an autoencoder produces an image that is slightly shifted—despite the near-identical appearance, the MSE can significantly increase, showcasing a limitation in capturing true visual fidelity. A potential solution involves leveraging a pre-trained CNN to measure distance based on visual features extracted from lower layers, offering a more nuanced comparison than pixel-level MSE.
|
||||||
|
#
|
||||||
|
|
||||||
|
# In[134]:
|
||||||
|
|
||||||
|
|
||||||
|
class Autoencoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
base_channel_size: int,
|
||||||
|
latent_dim: int,
|
||||||
|
encoder_class : object = Encoder,
|
||||||
|
decoder_class : object = Decoder,
|
||||||
|
num_input_channels: int = 3,
|
||||||
|
width: int = 32,
|
||||||
|
height: int = 32):
|
||||||
|
super().__init__()
|
||||||
|
# Creating encoder and decoder
|
||||||
|
self.encoder = encoder_class(num_input_channels, base_channel_size, latent_dim)
|
||||||
|
self.decoder = decoder_class(num_input_channels, base_channel_size, latent_dim)
|
||||||
|
# Example input array needed for visualizing the graph of the network
|
||||||
|
self.example_input_array = torch.zeros(2, num_input_channels, width, height)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
z = self.encoder(x)
|
||||||
|
x_hat = self.decoder(z)
|
||||||
|
return x_hat
|
||||||
|
|
||||||
|
def _get_reconstruction_loss(self, batch):
|
||||||
|
x = batch # We do not need the labels
|
||||||
|
x_hat = self.forward(x)
|
||||||
|
loss = F.mse_loss(x, x_hat, reduction="none")
|
||||||
|
loss = loss.sum(dim=[1,2,3]).mean(dim=[0])
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
# # Utility code for comparing Images
|
||||||
|
|
||||||
|
# In[14]:
|
||||||
|
|
||||||
|
|
||||||
|
def compare_imgs(img1, img2, title_prefix=""):
|
||||||
|
# Calculate MSE loss between both images
|
||||||
|
loss = F.mse_loss(img1, img2, reduction="sum")
|
||||||
|
# Plot images for visual comparison
|
||||||
|
grid = torchvision.utils.make_grid(torch.stack([img1, img2], dim=0), nrow=2, normalize=True)
|
||||||
|
grid = grid.permute(1, 2, 0)
|
||||||
|
plt.figure(figsize=(4,2))
|
||||||
|
plt.title(f"{title_prefix} Loss: {loss.item():4.2f}")
|
||||||
|
plt.imshow(grid)
|
||||||
|
plt.axis('off')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
for i in range(2):
|
||||||
|
# Load example image
|
||||||
|
img, _ = train_dataset[i]
|
||||||
|
img_mean = img.mean(dim=[1,2], keepdims=True)
|
||||||
|
|
||||||
|
# Shift image by one pixel
|
||||||
|
SHIFT = 1
|
||||||
|
img_shifted = torch.roll(img, shifts=SHIFT, dims=1)
|
||||||
|
img_shifted = torch.roll(img_shifted, shifts=SHIFT, dims=2)
|
||||||
|
img_shifted[:,:1,:] = img_mean
|
||||||
|
img_shifted[:,:,:1] = img_mean
|
||||||
|
compare_imgs(img, img_shifted, "Shifted -")
|
||||||
|
|
||||||
|
# Set half of the image to zero
|
||||||
|
img_masked = img.clone()
|
||||||
|
img_masked[:,:img_masked.shape[1]//2,:] = img_mean
|
||||||
|
compare_imgs(img, img_masked, "Masked -")
|
||||||
|
|
||||||
|
|
||||||
|
# # Task2
|
||||||
|
# Add training code to train the AutoEncoder
|
||||||
|
|
||||||
|
# In[2]:
|
||||||
|
|
||||||
|
|
||||||
|
# for batch in tqdm(train_loader, total=len(train_loader)):
|
||||||
|
import torch.optim as optim
|
||||||
|
|
||||||
|
model = Autoencoder(64, 128, ) # you code here
|
||||||
|
model.to(device)
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) # your code here
|
||||||
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') # your code here, can use ReduceLROnPlateau
|
||||||
|
# Write training loop here
|
||||||
|
|
||||||
|
loss_fn = nn.MSELoss()
|
||||||
|
|
||||||
|
n_epoch = 40
|
||||||
|
model.train()
|
||||||
|
for epoch in range(n_epoch):
|
||||||
|
print(f"\nEpoch {epoch}:")
|
||||||
|
|
||||||
|
avg_loss = 0
|
||||||
|
|
||||||
|
for i, data in enumerate(train_loader):
|
||||||
|
inputs, _ = data
|
||||||
|
|
||||||
|
inputs = inputs.cuda()
|
||||||
|
|
||||||
|
loss = model._get_reconstruction_loss(inputs) #loss_fn(outputs, inputs)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
avg_loss += loss
|
||||||
|
|
||||||
|
print(f'\rBatch: {i}: Loss:{loss} avg_Loss: {avg_loss/(i + 1)} ', end='')
|
||||||
|
|
||||||
|
scheduler.step(loss)
|
||||||
|
|
||||||
|
|
||||||
|
# In[144]:
|
||||||
|
|
||||||
|
|
||||||
|
def visualize_reconstructions(model, input_imgs):
|
||||||
|
# Reconstruct images
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
reconst_imgs = model(input_imgs.to(device))
|
||||||
|
reconst_imgs = reconst_imgs.cpu()
|
||||||
|
|
||||||
|
# Plotting
|
||||||
|
imgs = torch.stack([input_imgs, reconst_imgs], dim=1).flatten(0,1)
|
||||||
|
grid = torchvision.utils.make_grid(imgs, nrow=4, normalize=True)
|
||||||
|
grid = grid.permute(1, 2, 0)
|
||||||
|
plt.figure(figsize=(7,4.5))
|
||||||
|
plt.title(f"Reconstructed from model")
|
||||||
|
plt.imshow(grid)
|
||||||
|
plt.axis('off')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
input_imgs = get_train_images(6)
|
||||||
|
visualize_reconstructions(model, input_imgs)
|
||||||
|
|
||||||
|
|
||||||
|
# # Masked AutoEncoder
|
||||||
|
# The follow code are the demonstration of Masked Autoencoder implementation and visualization
|
||||||
|
|
||||||
|
# # Import Necessary Libraries
|
||||||
|
|
||||||
|
# In[4]:
|
||||||
|
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import requests
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# check whether run in Colab
|
||||||
|
if 'google.colab' in sys.modules:
|
||||||
|
print('Running in Colab.')
|
||||||
|
get_ipython().system('pip3 install timm==0.4.5 # 0.3.2 does not work in Colab')
|
||||||
|
get_ipython().system('git clone https://github.com/facebookresearch/mae.git')
|
||||||
|
sys.path.append('./mae')
|
||||||
|
else:
|
||||||
|
sys.path.append('./mae')
|
||||||
|
import models_mae
|
||||||
|
|
||||||
|
|
||||||
|
# # Build up necessary utillities
|
||||||
|
|
||||||
|
# In[131]:
|
||||||
|
|
||||||
|
|
||||||
|
# define the utils
|
||||||
|
|
||||||
|
imagenet_mean = np.array([0.485, 0.456, 0.406])
|
||||||
|
imagenet_std = np.array([0.229, 0.224, 0.225])
|
||||||
|
|
||||||
|
def show_image(image, title=''):
|
||||||
|
# image is [H, W, 3]
|
||||||
|
assert image.shape[2] == 3
|
||||||
|
plt.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int())
|
||||||
|
plt.title(title, fontsize=16)
|
||||||
|
plt.axis('off')
|
||||||
|
return
|
||||||
|
|
||||||
|
def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
|
||||||
|
# build model
|
||||||
|
model = getattr(models_mae, arch)()
|
||||||
|
# load model
|
||||||
|
checkpoint = torch.load(chkpt_dir, map_location='cpu')
|
||||||
|
msg = model.load_state_dict(checkpoint['model'], strict=False)
|
||||||
|
print(msg)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def run_one_image(img, model):
|
||||||
|
x = torch.tensor(img)
|
||||||
|
|
||||||
|
# make it a batch-like
|
||||||
|
x = x.unsqueeze(dim=0)
|
||||||
|
x = torch.einsum('nhwc->nchw', x)
|
||||||
|
|
||||||
|
# run MAE
|
||||||
|
loss, y, mask = model(x.float(), mask_ratio= 0.75)
|
||||||
|
y = model.unpatchify(y)
|
||||||
|
y = torch.einsum('nchw->nhwc', y).detach().cpu()
|
||||||
|
|
||||||
|
# visualize the mask
|
||||||
|
mask = mask.detach()
|
||||||
|
mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0]**2 *3) # (N, H*W, p*p*3)
|
||||||
|
mask = model.unpatchify(mask) # 1 is removing, 0 is keeping
|
||||||
|
mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
|
||||||
|
|
||||||
|
x = torch.einsum('nchw->nhwc', x)
|
||||||
|
|
||||||
|
# masked image
|
||||||
|
im_masked = x * (1 - mask)
|
||||||
|
|
||||||
|
# MAE reconstruction pasted with visible patches
|
||||||
|
im_paste = x * (1 - mask) + y * mask
|
||||||
|
|
||||||
|
# make the plt figure larger
|
||||||
|
plt.rcParams['figure.figsize'] = [24, 24]
|
||||||
|
|
||||||
|
plt.subplot(1, 4, 1)
|
||||||
|
show_image(x[0], "original")
|
||||||
|
|
||||||
|
plt.subplot(1, 4, 2)
|
||||||
|
show_image(im_masked[0], "masked")
|
||||||
|
|
||||||
|
plt.subplot(1, 4, 3)
|
||||||
|
show_image(y[0], "reconstruction")
|
||||||
|
|
||||||
|
plt.subplot(1, 4, 4)
|
||||||
|
show_image(im_paste[0], "reconstruction + visible")
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
# # Load one image
|
||||||
|
|
||||||
|
# In[189]:
|
||||||
|
|
||||||
|
|
||||||
|
# load an image
|
||||||
|
img_url = 'https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg' # fox, from ILSVRC2012_val_00046145
|
||||||
|
# img_url = 'https://user-images.githubusercontent.com/11435359/147743081-0428eecf-89e5-4e07-8da5-a30fd73cc0ba.jpg' # cucumber, from ILSVRC2012_val_00047851
|
||||||
|
img = Image.open(requests.get(img_url, stream=True).raw)
|
||||||
|
img = img.resize((224, 224))
|
||||||
|
img = np.array(img) / 255.
|
||||||
|
|
||||||
|
assert img.shape == (224, 224, 3)
|
||||||
|
|
||||||
|
# normalize by ImageNet mean and std
|
||||||
|
img = img - imagenet_mean
|
||||||
|
img = img / imagenet_std
|
||||||
|
|
||||||
|
plt.rcParams['figure.figsize'] = [5, 5]
|
||||||
|
show_image(torch.tensor(img))
|
||||||
|
|
||||||
|
|
||||||
|
# In[141]:
|
||||||
|
|
||||||
|
|
||||||
|
# Patch for numpy error
|
||||||
|
np.float = float
|
||||||
|
np.int = int #module 'numpy' has no attribute 'int'
|
||||||
|
np.object = object #module 'numpy' has no attribute 'object'
|
||||||
|
np.bool = bool #module 'numpy' has no attribute 'bool'
|
||||||
|
# This is an MAE model trained with pixels as targets for visualization (ViT-Large, training mask ratio=0.75)
|
||||||
|
|
||||||
|
# download checkpoint if not exist
|
||||||
|
get_ipython().system('wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth')
|
||||||
|
|
||||||
|
chkpt_dir = 'mae_visualize_vit_large.pth'
|
||||||
|
model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16')
|
||||||
|
print('Model loaded.')
|
||||||
|
|
||||||
|
|
||||||
|
mine_img = Image.open('./st2/6644818.png', formats=('PNG',)).convert('RGB')# Image.open(requests.get(img_url, stream=True).raw)
|
||||||
|
|
||||||
|
# mine_img.show()
|
||||||
|
mine_img = mine_img.resize((224, 224))
|
||||||
|
|
||||||
|
mine_img = np.array(mine_img) / 255.
|
||||||
|
|
||||||
|
# print(mine_img.shape, mine_img[0][0])
|
||||||
|
|
||||||
|
assert mine_img.shape == (224, 224, 3)
|
||||||
|
|
||||||
|
# normalize by ImageNet mean and std
|
||||||
|
mine_img = mine_img - imagenet_mean
|
||||||
|
mine_img = mine_img / imagenet_std
|
||||||
|
|
||||||
|
plt.rcParams['figure.figsize'] = [5, 5]
|
||||||
|
show_image(torch.tensor(mine_img))
|
||||||
|
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
d = pd.read_csv('st2/6644818/shuffle_info.csv', header=None)
|
||||||
|
|
||||||
|
ids_keep = torch.Tensor(eval(d.loc[0][1])).type(torch.int64)
|
||||||
|
ids_restore = torch.Tensor(eval(d.loc[1][1])).type(torch.int64)
|
||||||
|
|
||||||
|
def masking(self, x):
|
||||||
|
N, L, D = x.shape # batch, length, dim
|
||||||
|
return torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) # Creates the masked images
|
||||||
|
|
||||||
|
def forward_encoder(self, x):
|
||||||
|
# embed patches
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
|
||||||
|
# add pos embed w/o cls token
|
||||||
|
x = x + self.pos_embed[:, 1:, :]
|
||||||
|
|
||||||
|
x = masking(self, x)
|
||||||
|
|
||||||
|
# append cls token
|
||||||
|
cls_token = self.cls_token + model.pos_embed[:, :1, :]
|
||||||
|
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
|
||||||
|
x = torch.cat((cls_tokens, x), dim=1)
|
||||||
|
|
||||||
|
# apply Transformer blocks
|
||||||
|
for blk in self.blocks:
|
||||||
|
x = blk(x)
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def restore_one_image(img, model):
|
||||||
|
x = torch.tensor(img)
|
||||||
|
|
||||||
|
# make it a batch-like
|
||||||
|
x = x.unsqueeze(dim=0)
|
||||||
|
x = torch.einsum('nhwc->nchw', x)
|
||||||
|
|
||||||
|
temp_x = forward_encoder(model, x.float())
|
||||||
|
|
||||||
|
y = model.forward_decoder(temp_x, ids_restore)
|
||||||
|
y = model.unpatchify(y)
|
||||||
|
y = torch.einsum('nchw->nhwc', y).detach().cpu()
|
||||||
|
|
||||||
|
x = torch.einsum('nchw->nhwc', x)
|
||||||
|
|
||||||
|
# make the plt figure larger
|
||||||
|
plt.rcParams['figure.figsize'] = [12, 12]
|
||||||
|
|
||||||
|
plt.subplot(1, 2, 1)
|
||||||
|
show_image(x[0], "original")
|
||||||
|
|
||||||
|
plt.subplot(1, 2, 2)
|
||||||
|
show_image(y[0], "reconstruction")
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
torch.manual_seed(5)
|
||||||
|
print('MAE with pixel reconstruction:')
|
||||||
|
restore_one_image(mine_img, model_mae)
|
||||||
|
|
||||||
|
|
||||||
|
# In[ ]:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user