-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #132 from tmu-nlp/osada
'chapter08'
- Loading branch information
Showing
17 changed files
with
17 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyOcTUgfOIIc9GXB0uc6mlCG"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"ddagSuvVYIbG"},"outputs":[],"source":[]},{"cell_type":"markdown","source":["“United States”と”U.S.”のコサイン類似度を計算せよ."],"metadata":{"id":"_xUVvVidBgtv"}},{"cell_type":"code","source":["#61\n","from gensim.models import keyedvectors\n","\n","# Location of the word2vec binary used by this notebook.\n","MODEL_PATH = \"/content/drive/MyDrive/NLP100knock/GoogleNews-vectors-negative300.bin.gz\"\n","model = keyedvectors.load_word2vec_format(MODEL_PATH, binary=True)\n","\n","# similarity() gives the cosine similarity between the two words.\n","cosine = model.similarity(\"United_States\", \"U.S.\")\n","print(cosine)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"bTtfChFQx2uG","executionInfo":{"status":"ok","timestamp":1718608297820,"user_tz":-540,"elapsed":71314,"user":{"displayName":"OSADA MANATO","userId":"03457769063882444574"}},"outputId":"eb16a928-afba-49c1-c220-b731bfdb3de3"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["0.73107743\n"]}]}]} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyOYFTDMOm9U8AmmHr8rrk8x"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"W7rG5TaVYM5W"},"outputs":[],"source":[]},{"cell_type":"markdown","source":["“United States”とコサイン類似度が高い10語と,その類似度を出力せよ."],"metadata":{"id":"lNPO6jB1Bq7K"}},{"cell_type":"code","source":["#62\n","from gensim.models import keyedvectors\n","\n","# Location of the word2vec binary used by this notebook.\n","MODEL_PATH = \"/content/drive/MyDrive/NLP100knock/GoogleNews-vectors-negative300.bin.gz\"\n","model = keyedvectors.load_word2vec_format(MODEL_PATH, binary=True)\n","\n","# most_similar() ranks words by cosine similarity; topn sets how many of the top words are returned.\n","neighbours = model.most_similar(\"United_States\", topn=10)\n","print(neighbours)\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"6NQmf7bnBfYA","executionInfo":{"status":"ok","timestamp":1718608451755,"user_tz":-540,"elapsed":90538,"user":{"displayName":"OSADA MANATO","userId":"03457769063882444574"}},"outputId":"88adda23-7109-40f6-c672-5fb4e6d23773"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["[('Unites_States', 0.7877248525619507), ('Untied_States', 0.7541370987892151), ('United_Sates', 0.7400724291801453), ('U.S.', 0.7310774326324463), ('theUnited_States', 0.6404393911361694), ('America', 0.6178410053253174), ('UnitedStates', 0.6167312264442444), ('Europe', 0.6132988929748535), ('countries', 0.6044804453849792), ('Canada', 0.601906955242157)]\n"]}]}]} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyOOG25V4Coa3Uco8J8JoycJ"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"ge3inwrwYRgm"},"outputs":[],"source":[]},{"cell_type":"markdown","source":["“Spain”の単語ベクトルから”Madrid”のベクトルを引き,”Athens”のベクトルを足したベクトルを計算し,そのベクトルと類似度の高い10語とその類似度を出力せよ."],"metadata":{"id":"67WzUjERCUXE"}},{"cell_type":"code","source":["#63\n","from gensim.models import keyedvectors\n","\n","# Location of the word2vec binary used by this notebook.\n","MODEL_PATH = \"/content/drive/MyDrive/NLP100knock/GoogleNews-vectors-negative300.bin.gz\"\n","model = keyedvectors.load_word2vec_format(MODEL_PATH, binary=True)\n","\n","# positive lists the keys that contribute positively, negative the keys that contribute\n","# negatively (in effect a vector subtraction): vec(Spain) - vec(Madrid) + vec(Athens).\n","# ref: https://radimrehurek.com/gensim/models/keyedvectors.html\n","analogy = model.most_similar(positive=[\"Spain\", \"Athens\"], negative=[\"Madrid\"], topn=10)\n","print(analogy)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"IrkuX35CCYVg","executionInfo":{"status":"ok","timestamp":1718608582978,"user_tz":-540,"elapsed":76891,"user":{"displayName":"OSADA MANATO","userId":"03457769063882444574"}},"outputId":"5badda7c-69b1-454b-a834-fb03d366f424"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["[('Greece', 0.6898480653762817), ('Aristeidis_Grigoriadis', 0.560684859752655), ('Ioannis_Drymonakos', 0.5552908778190613), ('Greeks', 0.545068621635437), ('Ioannis_Christou', 0.5400862097740173), ('Hrysopiyi_Devetzi', 0.5248445272445679), ('Heraklio', 0.5207759737968445), ('Athens_Greece', 0.516880989074707), ('Lithuania', 0.5166865587234497), ('Iraklion', 0.5146791338920593)]\n"]}]}]} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyOXgtsFLOn64aw6B5I2H6iQ"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["単語アナロジーの評価データをダウンロードし,vec(2列目の単語) - vec(1列目の単語) + vec(3列目の単語)を計算し,そのベクトルと類似度が最も高い単語と,その類似度を求めよ.求めた単語と類似度は,各事例の末尾に追記せよ"],"metadata":{"id":"E4_F8DBnCreX"}},{"cell_type":"code","source":["!wget http://download.tensorflow.org/data/questions-words.txt"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"cCpn6VETDMxx","executionInfo":{"status":"ok","timestamp":1718608584229,"user_tz":-540,"elapsed":436,"user":{"displayName":"OSADA MANATO","userId":"03457769063882444574"}},"outputId":"51ccbef8-e822-4dbd-9949-78231d62c353"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["--2024-06-17 07:16:23-- http://download.tensorflow.org/data/questions-words.txt\n","Resolving download.tensorflow.org (download.tensorflow.org)... 172.217.203.207, 142.250.98.207, 142.251.107.207, ...\n","Connecting to download.tensorflow.org (download.tensorflow.org)|172.217.203.207|:80... connected.\n","HTTP request sent, awaiting response... 
200 OK\n","Length: 603955 (590K) [text/plain]\n","Saving to: ‘questions-words.txt’\n","\n","\rquestions-words.txt 0%[ ] 0 --.-KB/s \rquestions-words.txt 100%[===================>] 589.80K --.-KB/s in 0.003s \n","\n","2024-06-17 07:16:23 (228 MB/s) - ‘questions-words.txt’ saved [603955/603955]\n","\n"]}]},{"cell_type":"code","source":["# The import lives in the same cell as its first use so this cell runs on a fresh kernel.\n","from gensim.models import KeyedVectors\n","\n","filepath = \"/content/drive/MyDrive/NLP100knock/GoogleNews-vectors-negative300.bin.gz\"\n","wv_from_bin = KeyedVectors.load_word2vec_format(filepath, binary=True)"],"metadata":{"id":"rJmmyjFkKAd7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#64\n","!pip install tqdm\n","from tqdm import tqdm\n","\n","# datas holds one row per analogy:\n","# [section, column 1, column 2, column 3, column 4, predicted word, similarity as str]\n","datas = []\n","stat = '0'  # name of the section currently being read\n","with open('questions-words.txt', 'r') as f1, open('64.txt', 'w') as f2:\n","    for line in tqdm(f1):\n","        line = line.rstrip('\\n')\n","        if not line:\n","            # A blank line carries no analogy; skip it instead of failing on the indexing below.\n","            continue\n","        if line.startswith(':'):\n","            # Lines starting with ':' name a section; remember the name and copy the line through.\n","            stat = line.replace(': ', '')\n","            f2.write(line + '\\n')\n","            continue\n","        vecs = line.split(' ')\n","        # vec(column 2) - vec(column 1) + vec(column 3); keep only the best match.\n","        word, score = wv_from_bin.most_similar(positive=[vecs[1], vecs[2]], negative=[vecs[0]], topn=1)[0]\n","        # The task asks for both the word and its similarity at the end of each case.\n","        f2.write('{} {} {}\\n'.format(line, word, score))\n","        datas.append([stat] + vecs + [word, str(score)])\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Q1Q7OAg-CwC_","executionInfo":{"status":"ok","timestamp":1718618997718,"user_tz":-540,"elapsed":8573924,"user":{"displayName":"OSADA MANATO","userId":"03457769063882444574"}},"outputId":"95a9314a-4c9f-429a-f587-8ac3b4fc1dd0"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (4.66.2)\n"]},{"output_type":"stream","name":"stderr","text":["19558it [2:22:46, 2.28it/s]\n"]}]}]} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyPzJ/hlr3U8wSV3CEmx2Q8g"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"G971uO2oX-s7"},"outputs":[],"source":[]},{"cell_type":"markdown","source":["64の実行結果を用い,意味的アナロジー(semantic analogy)と文法的アナロジー(syntactic analogy)の正解率を測定せよ."],"metadata":{"id":"7BOjhCIQDxeR"}},{"cell_type":"code","source":["#65\n","import numpy as np\n","\n","# Rebuild the analogy rows from 64.txt (written by #64) so this notebook runs on a fresh\n","# kernel instead of relying on a `datas` list left over from another session.\n","# Row layout: [section, column 1, column 2, column 3, column 4, predicted word, ...]\n","datas = []\n","stat = '0'  # name of the section currently being read\n","with open('64.txt', 'r') as f:\n","    for line in f:\n","        line = line.rstrip('\\n')\n","        if line.startswith(':'):\n","            stat = line.replace(': ', '')\n","        elif line:\n","            datas.append([stat] + line.split(' '))\n","\n","# Sections whose name starts with 'gram' are counted as syntactic, all others as semantic.\n","sem = [d for d in datas if not d[0].startswith('gram')]\n","syn = [d for d in datas if d[0].startswith('gram')]\n","\n","# A case is correct when the expected word (index 4) equals the predicted word (index 5).\n","print(\"semantic analogy: {}\".format(np.mean([e[4] == e[5] for e in sem])))\n","print(\"syntactic analogy: {}\".format(np.mean([e[4] == e[5] for e in syn])))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"17E0MrmID1jD","executionInfo":{"status":"ok","timestamp":1718619359260,"user_tz":-540,"elapsed":322,"user":{"displayName":"OSADA MANATO","userId":"03457769063882444574"}},"outputId":"0b63fd67-4ed2-4522-f83e-4c6a5e9041ef"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["semantic analogy: 0.7308602999210734\n","syntactic analogy: 0.7400468384074942\n"]}]}]} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyNpTQUlTeank0Ntw1yR1Ogg"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"M_yCKiR_YsNZ"},"outputs":[],"source":[]},{"cell_type":"markdown","source":["The WordSimilarity-353 Test Collectionの評価データをダウンロードし,単語ベクトルにより計算される類似度のランキングと,人間の類似度判定のランキングの間のスピアマン相関係数を計算せよ."],"metadata":{"id":"PznaJKLMERf_"}},{"cell_type":"code","source":["#66\n","# NOTE(review): archive URL and file layout are assumed from the WordSim-353 distribution -- confirm before relying on them.\n","!wget -nc http://www.gabrilovich.com/resources/data/wordsim353/wordsim353.zip\n","!unzip -o -q wordsim353.zip"],"metadata":{},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["import csv\n","\n","import numpy as np\n","import pandas as pd\n","from gensim.models import KeyedVectors\n","from scipy.stats import rankdata\n","\n","filepath = \"/content/drive/MyDrive/NLP100knock/GoogleNews-vectors-negative300.bin.gz\"\n","wv_from_bin = KeyedVectors.load_word2vec_format(filepath, binary=True)\n","\n","# One row per word pair: [word 1, word 2, human mean score, cosine similarity of the word vectors].\n","# Assumes combined.csv has a header row followed by three columns (word 1, word 2, human mean).\n","data = []\n","with open('combined.csv', 'r') as f:\n","    reader = csv.reader(f)\n","    next(reader)  # skip the header row\n","    for word1, word2, human in reader:\n","        data.append([word1, word2, float(human), float(wv_from_bin.similarity(word1, word2))])"],"metadata":{},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Preview of the first rows (at most 15).\n","pd.DataFrame(\n","    data[:15],\n","    columns=['Word1', 'Word2', 'Human(mean)', 'Similarity(Vec)']\n",")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":216},"id":"pqgjIx6yOmos","executionInfo":{"status":"error","timestamp":1718619406169,"user_tz":-540,"elapsed":2389,"user":{"displayName":"OSADA MANATO","userId":"03457769063882444574"}},"outputId":"4a8cd0b0-559d-4ea6-b3d5-5073bc6d4709"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def spearman(lst1, lst2):\n","    \"\"\"Return the Spearman rank correlation of two equally long sequences of ranks.\n","\n","    Computed as the Pearson correlation of the ranks. This equals\n","    1 - 6*sum(d^2)/(N^3-N) when there are no ties and stays exact when ranks are tied,\n","    which the shortcut formula does not.\n","    \"\"\"\n","    lst1 = np.asarray(lst1, dtype=float)\n","    lst2 = np.asarray(lst2, dtype=float)\n","    return float(np.corrcoef(lst1, lst2)[0, 1])\n","\n","print(spearman(\n","    rankdata([float(d[2]) for d in data]),\n","    rankdata([float(d[3]) for d in data])\n","))"],"metadata":{"id":"K04WYDO4O_YM"},"execution_count":null,"outputs":[]}]} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyPT7m6BVFCMZ78lqtMnWxYa"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"smavY3o_Yx-N"},"outputs":[],"source":["\n"]},{"cell_type":"markdown","source":["国名に関する単語ベクトルを抽出し,k-meansクラスタリングをクラスタ数k=5として実行せよ."],"metadata":{"id":"HuP1sb7iPQ3o"}},{"cell_type":"code","source":["#67\n","import numpy as np\n","from gensim.models import KeyedVectors\n","from sklearn.cluster import KMeans\n","\n","# Load the word vectors here so the notebook does not depend on another notebook's kernel state.\n","filepath = \"/content/drive/MyDrive/NLP100knock/GoogleNews-vectors-negative300.bin.gz\"\n","wv_from_bin = KeyedVectors.load_word2vec_format(filepath, binary=True)\n","\n","# Read the analogy file written by #64. Each row becomes [section, field 1, field 2, ...].\n","data = []\n","stat = '0'  # name of the section currently being read\n","with open('64.txt', 'r') as f:\n","    for line in f:\n","        line = line.rstrip('\\n')\n","        if line.startswith(':'):\n","            stat = line.replace(': ', '')\n","        elif line:\n","            data.append([stat] + line.split(' '))"],"metadata":{"id":"LRjmrM9vPQM3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Country names are the 2nd and 4th words of the rows in the two capital sections.\n","COUNTRY_SECTIONS = ('capital-common-countries', 'capital-world')\n","# Sorted so the order, and therefore the clustering input, is the same on every run.\n","countries = sorted({\n","    c\n","    for d in data\n","    if d[0] in COUNTRY_SECTIONS\n","    for c in (d[2], d[4])\n","})\n"],"metadata":{"id":"CAbOoKCNPaq9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# n_init is set explicitly to avoid the FutureWarning about its changing default;\n","# random_state makes the clustering reproducible.\n","kmeans = KMeans(n_clusters=5, n_init=10, random_state=0)\n","kmeans.fit([wv_from_bin[c] for c in countries])\n","\n","for i in range(5):\n","    cluster = np.where(kmeans.labels_ == i)[0]\n","    print('クラス', i)\n","    print(', '.join(countries[k] for k in cluster))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Pqqj5AWjPdRb","executionInfo":{"status":"ok","timestamp":1718619945895,"user_tz":-540,"elapsed":1558,"user":{"displayName":"OSADA MANATO","userId":"03457769063882444574"}},"outputId":"24359787-a96e-4e8d-d2c5-d36b6a8032f9"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["クラス 0\n","Jamaica, Peru, Venezuela, Guyana, Philippines, Dominica, Nicaragua, Belize, Uruguay, Ecuador, Taiwan, Samoa, Honduras, Suriname, Bahamas, Cuba, Fiji, Chile, Tuvalu\n","クラス 1\n","Latvia, Kyrgyzstan, Croatia, Moldova, Montenegro, Hungary, Armenia, Cyprus, Azerbaijan, Ukraine, Uzbekistan, Tajikistan, Turkmenistan, Bulgaria, Poland, Macedonia, Serbia, Kazakhstan, Greece, Romania, Estonia, Russia, Georgia, Turkey, Slovakia, Belarus, Slovenia, Lithuania, Albania\n","クラス 2\n","Iraq, Egypt, Vietnam, Eritrea, Syria, Thailand, Nepal, Bhutan, China, Algeria, Bahrain, Tunisia, Mauritania, Pakistan, Bangladesh, Qatar, Lebanon, Libya, Afghanistan, Indonesia, Sudan, Oman, Morocco, Somalia, Iran, Jordan, Laos\n","クラス 3\n","Belgium, Ireland, England, Finland, Canada, Portugal, Malta, Norway, France, Liechtenstein, Japan, Austria, Switzerland, Greenland, Sweden, Germany, Spain, Denmark, Australia, Italy\n","クラス 4\n","Malawi, Niger, Gabon, Burundi, Rwanda, Angola, Guinea, Senegal, Uganda, Ghana, Mozambique, Mali, Zimbabwe, Madagascar, Namibia, Gambia, Kenya, Liberia, Botswana, Nigeria, Zambia\n"]}]}]} |
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyNWZFia+pJzx3/DaRcorPY4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","source":["#73\n","from tqdm import tqdm\n","import torch\n","import numpy as np\n","\n","# Load the training features and labels.\n","x_train = np.loadtxt(\"./x_train.txt\", delimiter=\" \")\n","y_train = np.loadtxt(\"./y_train.txt\")\n","# Convert to PyTorch tensors (float32 features, int64 class indices).\n","x_train = torch.tensor(x_train, dtype=torch.float32)\n","y_train = torch.tensor(y_train, dtype=torch.int64)\n","\n","# Network: a single linear layer from 300 inputs to 4 outputs, without bias.\n","net = torch.nn.Linear(300, 4, bias=False)\n","# CrossEntropyLoss applies log-softmax itself, so it must receive the raw logits.\n","loss_fn = torch.nn.CrossEntropyLoss()\n","\n","# Stochastic gradient descent over the model parameters with learning rate 0.01.\n","optimizer = torch.optim.SGD(net.parameters(), lr=0.01)\n","\n","# Train for 100 epochs, each a full pass over x_train.\n","losses = []\n","for epoch in tqdm(range(100)):\n","    optimizer.zero_grad()  # reset the gradients\n","    logits = net(x_train)  # no softmax here: applying it before the loss would normalise twice\n","    loss = loss_fn(logits, y_train)  # compute the loss\n","    loss.backward()  # compute the gradients\n","    optimizer.step()  # apply the update\n","    losses.append(loss.item())  # keep a plain float rather than a graph-attached tensor\n","\n","# Save the trained weights.\n","torch.save(net.state_dict(), \"model.pt\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"LdayxCzmcjor","executionInfo":{"status":"ok","timestamp":1719168860143,"user_tz":-540,"elapsed":3628,"user":{"displayName":"OSADA MANATO","userId":"03457769063882444574"}},"outputId":"7bb995c7-61bb-4ca3-9d2a-341062109f72"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["100%|██████████| 100/100 [00:00<00:00, 110.35it/s]\n"]}]}]} |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyPK9YqOCedfk6MISIJI81zX"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","source":["#71\n","import torch\n","import numpy as np\n","\n","# Read the training features and convert them to a 32-bit float tensor\n","# (the original author notes the size as 10684 x 300).\n","features = np.loadtxt(\"./x_train.txt\", delimiter=\" \")\n","x_train = torch.tensor(features, dtype=torch.float32)\n","\n","# The weight matrix W is initialised with random values.\n","W = torch.rand(300, 4)\n","\n","# dim=1 normalises each row so that it sums to 1.\n","softmax = torch.nn.Softmax(dim=1)\n","\n","# Predicted probabilities for the first sample, then for the first four samples;\n","# matmul computes the matrix product of the two tensors.\n","for n_rows in (1, 4):\n","    print(softmax(torch.matmul(x_train[:n_rows], W)))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"zfe0EesObYrM","executionInfo":{"status":"ok","timestamp":1719168546814,"user_tz":-540,"elapsed":691,"user":{"displayName":"OSADA MANATO","userId":"03457769063882444574"}},"outputId":"1bed2f24-02f0-470e-fa10-fca8554ee3e1"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["tensor([[0.2263, 0.2019, 0.3276, 0.2442]])\n","tensor([[0.2263, 0.2019, 0.3276, 0.2442],\n"," [0.2009, 0.2275, 0.2434, 0.3281],\n"," [0.2149, 0.3109, 0.2759, 0.1983],\n"," [0.1952, 0.2211, 0.4297, 0.1539]])\n"]}]}]} |
Oops, something went wrong.