Skip to content

Commit

Permalink
Merge pull request #132 from tmu-nlp/osada
Browse files Browse the repository at this point in the history
'chapter08'
  • Loading branch information
kiyama-hajime authored Jun 24, 2024
2 parents 2e20a5c + 17bc2d0 commit 533fac3
Show file tree
Hide file tree
Showing 17 changed files with 17 additions and 0 deletions.
1 change: 1 addition & 0 deletions osada/chapter07/knock60.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions osada/chapter07/knock61.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyOcTUgfOIIc9GXB0uc6mlCG"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"ddagSuvVYIbG"},"outputs":[],"source":[]},{"cell_type":"markdown","source":["“United States”と”U.S.”のコサイン類似度を計算せよ."],"metadata":{"id":"_xUVvVidBgtv"}},{"cell_type":"code","source":["#61\n","from gensim.models import keyedvectors\n","\n","model = keyedvectors.load_word2vec_format(\"/content/drive/MyDrive/NLP100knock/GoogleNews-vectors-negative300.bin.gz\", binary=True)\n","#similarityでコサイン類似度が求められる\n","print(model.similarity(\"United_States\", \"U.S.\"))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"bTtfChFQx2uG","executionInfo":{"status":"ok","timestamp":1718608297820,"user_tz":-540,"elapsed":71314,"user":{"displayName":"OSADA MANATO","userId":"03457769063882444574"}},"outputId":"eb16a928-afba-49c1-c220-b731bfdb3de3"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["0.73107743\n"]}]}]}
1 change: 1 addition & 0 deletions osada/chapter07/knock62.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyOYFTDMOm9U8AmmHr8rrk8x"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"W7rG5TaVYM5W"},"outputs":[],"source":[]},{"cell_type":"markdown","source":["“United States”とコサイン類似度が高い10語と,その類似度を出力せよ."],"metadata":{"id":"lNPO6jB1Bq7K"}},{"cell_type":"code","source":["#62\n","from gensim.models import keyedvectors\n","\n","model = keyedvectors.load_word2vec_format(\"/content/drive/MyDrive/NLP100knock/GoogleNews-vectors-negative300.bin.gz\", binary=True)\n","#most_similarでコサイン類似度を求められる,topnオプションで類似度上位の単語数を指定\n","print(model.most_similar(\"United_States\", topn=10))\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"6NQmf7bnBfYA","executionInfo":{"status":"ok","timestamp":1718608451755,"user_tz":-540,"elapsed":90538,"user":{"displayName":"OSADA MANATO","userId":"03457769063882444574"}},"outputId":"88adda23-7109-40f6-c672-5fb4e6d23773"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["[('Unites_States', 0.7877248525619507), ('Untied_States', 0.7541370987892151), ('United_Sates', 0.7400724291801453), ('U.S.', 0.7310774326324463), ('theUnited_States', 0.6404393911361694), ('America', 0.6178410053253174), ('UnitedStates', 0.6167312264442444), ('Europe', 0.6132988929748535), ('countries', 0.6044804453849792), ('Canada', 0.601906955242157)]\n"]}]}]}
1 change: 1 addition & 0 deletions osada/chapter07/knock63.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyOOG25V4Coa3Uco8J8JoycJ"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"ge3inwrwYRgm"},"outputs":[],"source":[]},{"cell_type":"markdown","source":["“Spain”の単語ベクトルから”Madrid”のベクトルを引き,”Athens”のベクトルを足したベクトルを計算し,そのベクトルと類似度の高い10語とその類似度を出力せよ."],"metadata":{"id":"67WzUjERCUXE"}},{"cell_type":"code","source":["#63\n","from gensim.models import keyedvectors\n","\n","model = keyedvectors.load_word2vec_format(\"/content/drive/MyDrive/NLP100knock/GoogleNews-vectors-negative300.bin.gz\", binary=True)\n","#パラメータpositiveは積極的に関係するキーのリストを指定,negativeは否定的に関係する(要はベクトルの引き算)キーのリストを指定\n","#ref:https://radimrehurek.com/gensim/models/keyedvectors.html\n","print(model.most_similar(positive=[\"Spain\", \"Athens\"], negative=[\"Madrid\"], topn=10))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"IrkuX35CCYVg","executionInfo":{"status":"ok","timestamp":1718608582978,"user_tz":-540,"elapsed":76891,"user":{"displayName":"OSADA MANATO","userId":"03457769063882444574"}},"outputId":"5badda7c-69b1-454b-a834-fb03d366f424"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["[('Greece', 0.6898480653762817), ('Aristeidis_Grigoriadis', 0.560684859752655), ('Ioannis_Drymonakos', 0.5552908778190613), ('Greeks', 0.545068621635437), ('Ioannis_Christou', 0.5400862097740173), ('Hrysopiyi_Devetzi', 0.5248445272445679), ('Heraklio', 0.5207759737968445), ('Athens_Greece', 0.516880989074707), ('Lithuania', 0.5166865587234497), ('Iraklion', 0.5146791338920593)]\n"]}]}]}
1 change: 1 addition & 0 deletions osada/chapter07/knock64.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyOXgtsFLOn64aw6B5I2H6iQ"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["単語アナロジーの評価データをダウンロードし,vec(2列目の単語) - vec(1列目の単語) + vec(3列目の単語)を計算し,そのベクトルと類似度が最も高い単語と,その類似度を求めよ.求めた単語と類似度は,各事例の末尾に追記せよ"],"metadata":{"id":"E4_F8DBnCreX"}},{"cell_type":"code","source":["!wget http://download.tensorflow.org/data/questions-words.txt"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"cCpn6VETDMxx","executionInfo":{"status":"ok","timestamp":1718608584229,"user_tz":-540,"elapsed":436,"user":{"displayName":"OSADA MANATO","userId":"03457769063882444574"}},"outputId":"51ccbef8-e822-4dbd-9949-78231d62c353"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["--2024-06-17 07:16:23-- http://download.tensorflow.org/data/questions-words.txt\n","Resolving download.tensorflow.org (download.tensorflow.org)... 172.217.203.207, 142.250.98.207, 142.251.107.207, ...\n","Connecting to download.tensorflow.org (download.tensorflow.org)|172.217.203.207|:80... connected.\n","HTTP request sent, awaiting response... 200 OK\n","Length: 603955 (590K) [text/plain]\n","Saving to: ‘questions-words.txt’\n","\n","\rquestions-words.txt 0%[ ] 0 --.-KB/s \rquestions-words.txt 100%[===================>] 589.80K --.-KB/s in 0.003s \n","\n","2024-06-17 07:16:23 (228 MB/s) - ‘questions-words.txt’ saved [603955/603955]\n","\n"]}]},{"cell_type":"code","source":["filepath = \"/content/drive/MyDrive/NLP100knock/GoogleNews-vectors-negative300.bin.gz\"\n","wv_from_bin = KeyedVectors.load_word2vec_format(filepath, binary=True)"],"metadata":{"id":"rJmmyjFkKAd7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#64\n","!pip install tqdm\n","from tqdm import tqdm\n","import re\n","from gensim.models import KeyedVectors\n","\n","datas = []\n","stat = '0'\n","with open('questions-words.txt', 'r') as f1, open('64.txt', 'w') as f2:\n"," for line in tqdm(f1):\n"," if not re.match(r'^:', line):\n"," vecs = line.replace('\\n', '').split(' ')\n"," result = wv_from_bin.most_similar(positive=[vecs[1], vecs[2]], negative=[vecs[0]], topn=1)\n"," vecs.insert(0, stat)\n"," vecs.append(result[0][0])\n"," vecs.append(str(result[0][1]))\n"," datas.append(vecs)\n"," #print('\\r{} {}'.format(result[0][0], stat), end='')\n"," f2.write(line.replace('\\n', '')+' '+result[0][0]+'\\n')\n"," else:\n"," stat = line.replace('\\n', '').replace(': ', '')\n"," f2.write(line)\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Q1Q7OAg-CwC_","executionInfo":{"status":"ok","timestamp":1718618997718,"user_tz":-540,"elapsed":8573924,"user":{"displayName":"OSADA MANATO","userId":"03457769063882444574"}},"outputId":"95a9314a-4c9f-429a-f587-8ac3b4fc1dd0"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (4.66.2)\n"]},{"output_type":"stream","name":"stderr","text":["19558it [2:22:46, 2.28it/s]\n"]}]}]}
1 change: 1 addition & 0 deletions osada/chapter07/knock65.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyPzJ/hlr3U8wSV3CEmx2Q8g"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"G971uO2oX-s7"},"outputs":[],"source":[]},{"cell_type":"markdown","source":["64の実行結果を用い,意味的アナロジー(semantic analogy)と文法的アナロジー(syntactic analogy)の正解率を測定せよ."],"metadata":{"id":"7BOjhCIQDxeR"}},{"cell_type":"code","source":["#65\n","import re\n","import numpy as np\n","\n","sem = [d for d in datas if not re.match(r'^gram.*', d[0])]\n","syn= [d for d in datas if re.match(r'^gram.*', d[0])]\n","\n","print(\"semantic analogy: {}\".format(np.mean([e[4]==e[5] for e in sem])))\n","print(\"syntantic analogy: {}\".format(np.mean([e[4]==e[5] for e in syn])))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"17E0MrmID1jD","executionInfo":{"status":"ok","timestamp":1718619359260,"user_tz":-540,"elapsed":322,"user":{"displayName":"OSADA MANATO","userId":"03457769063882444574"}},"outputId":"0b63fd67-4ed2-4522-f83e-4c6a5e9041ef"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["semantic analogy: 0.7308602999210734\n","syntantic analogy: 0.7400468384074942\n"]}]}]}
1 change: 1 addition & 0 deletions osada/chapter07/knock66.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyNpTQUlTeank0Ntw1yR1Ogg"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"M_yCKiR_YsNZ"},"outputs":[],"source":[]},{"cell_type":"markdown","source":["The WordSimilarity-353 Test Collectionの評価データをダウンロードし,単語ベクトルにより計算される類似度のランキングと,人間の類似度判定のランキングの間のスピアマン相関係数を計算せよ."],"metadata":{"id":"PznaJKLMERf_"}},{"cell_type":"code","source":["import pandas as pd\n","\n","pd.DataFrame(\n"," data[:min(15, len(data))],\n"," columns = ['Word1', 'Word2', 'Human(mean)', 'Similarity(Vec)']\n",")\n","#うまく動かない"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":216},"id":"pqgjIx6yOmos","executionInfo":{"status":"error","timestamp":1718619406169,"user_tz":-540,"elapsed":2389,"user":{"displayName":"OSADA MANATO","userId":"03457769063882444574"}},"outputId":"4a8cd0b0-559d-4ea6-b3d5-5073bc6d4709"},"execution_count":null,"outputs":[{"output_type":"error","ename":"NameError","evalue":"name 'data' is not defined","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)","\u001b[0;32m<ipython-input-24-723a7239b136>\u001b[0m in \u001b[0;36m<cell line: 3>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m pd.DataFrame(\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m15\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mcolumns\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'Word1'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'Word2'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'Human(mean)'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'Similarity(Vec)'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m )\n","\u001b[0;31mNameError\u001b[0m: name 'data' is not defined"]}]},{"cell_type":"code","source":["from scipy.stats import rankdata\n","\n","def spearman(lst1, lst2):\n"," lst1 = np.array(lst1)\n"," lst2 = np.array(lst2)\n"," N = len(lst1)\n"," return 1 - (6*sum((lst1-lst2)**2) / (N**3-N))\n","\n","print(spearman(\n"," rankdata([float(d[2]) for d in data]),\n"," rankdata([float(d[3]) for d in data])\n","))"],"metadata":{"id":"K04WYDO4O_YM"},"execution_count":null,"outputs":[]}]}
1 change: 1 addition & 0 deletions osada/chapter07/knock67.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyPT7m6BVFCMZ78lqtMnWxYa"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"smavY3o_Yx-N"},"outputs":[],"source":["\n"]},{"cell_type":"markdown","source":["国名に関する単語ベクトルを抽出し,k-meansクラスタリングをクラスタ数k=5として実行せよ."],"metadata":{"id":"HuP1sb7iPQ3o"}},{"cell_type":"code","source":["#67\n","import re\n","\n","data = []\n","stat = '0'\n","with open('64.txt', 'r') as f:\n"," for line in f:\n"," if not re.match(r'^:', line):\n"," vecs = line.replace('\\n', '').split(' ')\n"," vecs.insert(0, stat)\n"," data.append(vecs)\n"," else:\n"," stat = line.replace('\\n', '').replace(': ', '')"],"metadata":{"id":"LRjmrM9vPQM3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["countries = {\n"," c\n"," for d in data\n"," for c in [d[2], d[4]]\n"," if d[0] in ['capital-common-countries', 'capital-world']\n","}\n","countries = list(countries)\n"],"metadata":{"id":"CAbOoKCNPaq9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from sklearn.cluster import KMeans\n","\n","kmeans = KMeans(n_clusters=5)\n","kmeans.fit([wv_from_bin[c] for c in countries])\n","\n","for i in range(5):\n"," cluster = np.where(kmeans.labels_ == i)[0]\n"," print('クラス', i)\n"," print(', '.join([countries[k] for k in cluster]))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Pqqj5AWjPdRb","executionInfo":{"status":"ok","timestamp":1718619945895,"user_tz":-540,"elapsed":1558,"user":{"displayName":"OSADA MANATO","userId":"03457769063882444574"}},"outputId":"24359787-a96e-4e8d-d2c5-d36b6a8032f9"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.10/dist-packages/sklearn/cluster/_kmeans.py:870: FutureWarning: The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the value of `n_init` explicitly to suppress the warning\n"," warnings.warn(\n"]},{"output_type":"stream","name":"stdout","text":["クラス 0\n","Jamaica, Peru, Venezuela, Guyana, Philippines, Dominica, Nicaragua, Belize, Uruguay, Ecuador, Taiwan, Samoa, Honduras, Suriname, Bahamas, Cuba, Fiji, Chile, Tuvalu\n","クラス 1\n","Latvia, Kyrgyzstan, Croatia, Moldova, Montenegro, Hungary, Armenia, Cyprus, Azerbaijan, Ukraine, Uzbekistan, Tajikistan, Turkmenistan, Bulgaria, Poland, Macedonia, Serbia, Kazakhstan, Greece, Romania, Estonia, Russia, Georgia, Turkey, Slovakia, Belarus, Slovenia, Lithuania, Albania\n","クラス 2\n","Iraq, Egypt, Vietnam, Eritrea, Syria, Thailand, Nepal, Bhutan, China, Algeria, Bahrain, Tunisia, Mauritania, Pakistan, Bangladesh, Qatar, Lebanon, Libya, Afghanistan, Indonesia, Sudan, Oman, Morocco, Somalia, Iran, Jordan, Laos\n","クラス 3\n","Belgium, Ireland, England, Finland, Canada, Portugal, Malta, Norway, France, Liechtenstein, Japan, Austria, Switzerland, Greenland, Sweden, Germany, Spain, Denmark, Australia, Italy\n","クラス 4\n","Malawi, Niger, Gabon, Burundi, Rwanda, Angola, Guinea, Senegal, Uganda, Ghana, Mozambique, Mali, Zimbabwe, Madagascar, Namibia, Gambia, Kenya, Liberia, Botswana, Nigeria, Zambia\n"]}]}]}
1 change: 1 addition & 0 deletions osada/chapter07/knock68.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions osada/chapter07/knock69.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions osada/chapter08/knck73.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyNWZFia+pJzx3/DaRcorPY4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","source":["#73\n","from tqdm import tqdm\n","import torch\n","import numpy as np\n","\n","#学習データ読み込み\n","x_train = np.loadtxt(\"./x_train.txt\", delimiter=\" \")\n","y_train = np.loadtxt(\"./y_train.txt\")\n","#pytorchのtensor型にする\n","x_train = torch.tensor(x_train, dtype=torch.float32)\n","y_train = torch.tensor(y_train, dtype=torch.int64)\n","\n","#ネットワーク作成\n","net = torch.nn.Linear(300, 4, bias=False)\n","loss_fn = torch.nn.CrossEntropyLoss()\n","\n","#確率的勾配降下法SGD(モデルのパラメータ,学習率)\n","optimizer = torch.optim.SGD(net.parameters(), lr=0.01)\n","\n","#100epochで学習終了\n","losses = []\n","for epoch in tqdm(range(100)):\n"," optimizer.zero_grad()#勾配を0で初期化\n"," y_pred = torch.softmax(net.forward(x_train), dim=1)\n"," loss = loss_fn(y_pred, y_train)#損失を計算\n"," loss.backward()#勾配を計算\n"," optimizer.step()#最適化ステップを実行\n"," losses.append(loss)\n","\n","#学習したモデルを保存\n","torch.save(net.state_dict(), \"model.pt\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"LdayxCzmcjor","executionInfo":{"status":"ok","timestamp":1719168860143,"user_tz":-540,"elapsed":3628,"user":{"displayName":"OSADA MANATO","userId":"03457769063882444574"}},"outputId":"7bb995c7-61bb-4ca3-9d2a-341062109f72"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["100%|██████████| 100/100 [00:00<00:00, 110.35it/s]\n"]}]}]}
1 change: 1 addition & 0 deletions osada/chapter08/knock70.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions osada/chapter08/knock71.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyPK9YqOCedfk6MISIJI81zX"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","source":["#71\n","import torch\n","import numpy as np\n","\n","#学習データ読み込み\n","x_train = np.loadtxt(\"./x_train.txt\", delimiter=\" \")\n","#pytorchのtensor型にする(32bit浮動小数点),サイズは10684×300\n","x_train = torch.tensor(x_train, dtype=torch.float32)\n","#print(x_train.size())\n","\n","#重み行列Wはランダムな値で初期化\n","W = torch.rand(300, 4)\n","\n","#softmaxで予測確率を得る,dim=1にすると行単位で合計1にしてくれる\n","softmax = torch.nn.Softmax(dim=1)\n","#matmulでtensorの行列積求める\n","print(softmax(torch.matmul(x_train[:1], W)))\n","print(softmax(torch.matmul(x_train[:4], W)))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"zfe0EesObYrM","executionInfo":{"status":"ok","timestamp":1719168546814,"user_tz":-540,"elapsed":691,"user":{"displayName":"OSADA MANATO","userId":"03457769063882444574"}},"outputId":"1bed2f24-02f0-470e-fa10-fca8554ee3e1"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["tensor([[0.2263, 0.2019, 0.3276, 0.2442]])\n","tensor([[0.2263, 0.2019, 0.3276, 0.2442],\n"," [0.2009, 0.2275, 0.2434, 0.3281],\n"," [0.2149, 0.3109, 0.2759, 0.1983],\n"," [0.1952, 0.2211, 0.4297, 0.1539]])\n"]}]}]}
Loading

0 comments on commit 533fac3

Please sign in to comment.