Merge pull request #115 from tmu-nlp/ritsu
add: chapter07
Showing 12 changed files with 425 additions and 0 deletions.
@@ -0,0 +1,97 @@
from gensim.models import KeyedVectors


def load_word_vectors(file_path):
    """Load pretrained word2vec vectors in binary format."""
    return KeyedVectors.load_word2vec_format(file_path, binary=True)


def main():
    file_path = 'GoogleNews-vectors-negative300.bin.gz'
    word_vectors = load_word_vectors(file_path)

    word = 'United_States'
    if word in word_vectors:
        vector = word_vectors[word]
        print(f"Vector for '{word}':")
        print(vector)
    else:
        print(f"'{word}' is not in the model.")


if __name__ == '__main__':
    main()

""" | ||
[-3.61328125e-02 -4.83398438e-02 2.35351562e-01 1.74804688e-01 | ||
-1.46484375e-01 -7.42187500e-02 -1.01562500e-01 -7.71484375e-02 | ||
1.09375000e-01 -5.71289062e-02 -1.48437500e-01 -6.00585938e-02 | ||
1.74804688e-01 -7.71484375e-02 2.58789062e-02 -7.66601562e-02 | ||
-3.80859375e-02 1.35742188e-01 3.75976562e-02 -4.19921875e-02 | ||
-3.56445312e-02 5.34667969e-02 3.68118286e-04 -1.66992188e-01 | ||
-1.17187500e-01 1.41601562e-01 -1.69921875e-01 -6.49414062e-02 | ||
-1.66992188e-01 1.00585938e-01 1.15722656e-01 -2.18750000e-01 | ||
-9.86328125e-02 -2.56347656e-02 1.23046875e-01 -3.54003906e-02 | ||
-1.58203125e-01 -1.60156250e-01 2.94189453e-02 8.15429688e-02 | ||
6.88476562e-02 1.87500000e-01 6.49414062e-02 1.15234375e-01 | ||
-2.27050781e-02 3.32031250e-01 -3.27148438e-02 1.77734375e-01 | ||
-2.08007812e-01 4.54101562e-02 -1.23901367e-02 1.19628906e-01 | ||
7.44628906e-03 -9.03320312e-03 1.14257812e-01 1.69921875e-01 | ||
-2.38281250e-01 -2.79541016e-02 -1.21093750e-01 2.47802734e-02 | ||
7.71484375e-02 -2.81982422e-02 -4.71191406e-02 1.78222656e-02 | ||
-1.23046875e-01 -5.32226562e-02 2.68554688e-02 -3.11279297e-02 | ||
-5.59082031e-02 -5.00488281e-02 -3.73535156e-02 1.25976562e-01 | ||
5.61523438e-02 1.51367188e-01 4.29687500e-02 -2.08007812e-01 | ||
-4.78515625e-02 2.78320312e-02 1.81640625e-01 2.20703125e-01 | ||
-3.61328125e-02 -8.39843750e-02 -3.69548798e-05 -9.52148438e-02 | ||
-1.25000000e-01 -1.95312500e-01 -1.50390625e-01 -4.15039062e-02 | ||
1.31835938e-01 1.17675781e-01 1.91650391e-02 5.51757812e-02 | ||
-9.42382812e-02 -1.08886719e-01 7.32421875e-02 -1.15234375e-01 | ||
8.93554688e-02 -1.40625000e-01 1.45507812e-01 4.49218750e-02 | ||
-1.10473633e-02 -1.62353516e-02 4.05883789e-03 3.75976562e-02 | ||
-6.98242188e-02 -5.46875000e-02 2.17285156e-02 -9.47265625e-02 | ||
4.24804688e-02 1.81884766e-02 -1.73339844e-02 4.63867188e-02 | ||
-1.42578125e-01 1.99218750e-01 1.10839844e-01 2.58789062e-02 | ||
-7.08007812e-02 -5.54199219e-02 3.45703125e-01 1.61132812e-01 | ||
-2.44140625e-01 -2.59765625e-01 -9.71679688e-02 8.00781250e-02 | ||
-8.78906250e-02 -7.22656250e-02 1.42578125e-01 -8.54492188e-02 | ||
-3.18359375e-01 8.30078125e-02 6.34765625e-02 1.64062500e-01 | ||
-1.92382812e-01 -1.17675781e-01 -5.41992188e-02 -1.56250000e-01 | ||
-1.21582031e-01 -4.95605469e-02 1.20117188e-01 -3.83300781e-02 | ||
5.51757812e-02 -8.97216797e-03 4.32128906e-02 6.93359375e-02 | ||
8.93554688e-02 2.53906250e-01 1.65039062e-01 1.64062500e-01 | ||
-1.41601562e-01 4.58984375e-02 1.97265625e-01 -8.98437500e-02 | ||
3.90625000e-02 -1.51367188e-01 -8.60595703e-03 -1.17675781e-01 | ||
-1.97265625e-01 -1.12792969e-01 1.29882812e-01 1.96289062e-01 | ||
1.56402588e-03 3.93066406e-02 2.17773438e-01 -1.43554688e-01 | ||
6.03027344e-02 -1.35742188e-01 1.16210938e-01 -1.59912109e-02 | ||
2.79296875e-01 1.46484375e-01 -1.19628906e-01 1.76757812e-01 | ||
1.28906250e-01 -1.49414062e-01 6.93359375e-02 -1.72851562e-01 | ||
9.22851562e-02 1.33056641e-02 -2.00195312e-01 -9.76562500e-02 | ||
-1.65039062e-01 -2.46093750e-01 -2.35595703e-02 -2.11914062e-01 | ||
1.84570312e-01 -1.85546875e-02 2.16796875e-01 5.05371094e-02 | ||
2.02636719e-02 4.25781250e-01 1.28906250e-01 -2.77099609e-02 | ||
1.29882812e-01 -1.15722656e-01 -2.05078125e-02 1.49414062e-01 | ||
7.81250000e-03 -2.05078125e-01 -8.05664062e-02 -2.67578125e-01 | ||
-2.29492188e-02 -8.20312500e-02 8.64257812e-02 7.61718750e-02 | ||
-3.66210938e-02 5.22460938e-02 -1.22070312e-01 -1.44042969e-02 | ||
-2.69531250e-01 8.44726562e-02 -2.52685547e-02 -2.96630859e-02 | ||
-1.68945312e-01 1.93359375e-01 -1.08398438e-01 1.94091797e-02 | ||
-1.80664062e-01 1.93359375e-01 -7.08007812e-02 5.85937500e-02 | ||
-1.01562500e-01 -1.31835938e-01 7.51953125e-02 -7.66601562e-02 | ||
3.37219238e-03 -8.59375000e-02 1.25000000e-01 2.92968750e-02 | ||
1.70898438e-01 -9.37500000e-02 -1.09375000e-01 -2.50244141e-02 | ||
2.11914062e-01 -4.44335938e-02 6.12792969e-02 2.62451172e-02 | ||
-1.77734375e-01 1.23046875e-01 -7.42187500e-02 -1.67968750e-01 | ||
-1.08886719e-01 -9.04083252e-04 -7.37304688e-02 5.49316406e-02 | ||
6.03027344e-02 8.39843750e-02 9.17968750e-02 -1.32812500e-01 | ||
1.22070312e-01 -8.78906250e-03 1.19140625e-01 -1.94335938e-01 | ||
-6.64062500e-02 -2.07031250e-01 7.37304688e-02 8.93554688e-02 | ||
1.81884766e-02 -1.20605469e-01 -2.61230469e-02 2.67333984e-02 | ||
7.76367188e-02 -8.30078125e-02 6.78710938e-02 -3.54003906e-02 | ||
3.10546875e-01 -2.42919922e-02 -1.41601562e-01 -2.08007812e-01 | ||
-4.57763672e-03 -6.54296875e-02 -4.95605469e-02 2.22656250e-01 | ||
1.53320312e-01 -1.38671875e-01 -5.24902344e-02 4.24804688e-02 | ||
-2.38281250e-01 1.56250000e-01 5.83648682e-04 -1.20605469e-01 | ||
-9.22851562e-02 -4.44335938e-02 3.61328125e-02 -1.86767578e-02 | ||
-8.25195312e-02 -8.25195312e-02 -4.05273438e-02 1.19018555e-02 | ||
1.69921875e-01 -2.80761719e-02 3.03649902e-03 9.32617188e-02 | ||
-8.49609375e-02 1.57470703e-02 7.03125000e-02 1.62353516e-02 | ||
-2.27050781e-02 3.51562500e-02 2.47070312e-01 -2.67333984e-02] | ||
""" |
@@ -0,0 +1,24 @@
from knock60 import load_word_vectors


def main():
    """Main routine: cosine similarity between two words."""
    file_path = 'GoogleNews-vectors-negative300.bin.gz'
    word_vectors = load_word_vectors(file_path)

    word1 = 'United_States'
    word2 = 'U.S.'

    try:
        cosine_similarity = word_vectors.similarity(word1, word2)
        print(f"Cosine similarity between '{word1}' and '{word2}': {cosine_similarity}")
    except KeyError as e:
        print(f"'{e.args[0]}' is not in the model.")


if __name__ == '__main__':
    main()


"""
Cosine similarity between 'United_States' and 'U.S.': 0.7310774326324463
"""
@@ -0,0 +1,36 @@
from knock60 import load_word_vectors


def main():
    """Main routine: top-10 most similar words."""
    file_path = 'GoogleNews-vectors-negative300.bin.gz'
    word_vectors = load_word_vectors(file_path)

    word = 'United_States'

    try:
        similar_words = word_vectors.most_similar(word, topn=10)
        print(f"Top 10 words most similar to '{word}':")
        for similar_word, similarity in similar_words:
            print(f"word: {similar_word}, similarity: {similarity}")
    except KeyError as e:
        print(f"'{e.args[0]}' is not in the model.")


if __name__ == '__main__':
    main()


"""
Top 10 words most similar to 'United_States':
word: Unites_States, similarity: 0.7877248525619507
word: Untied_States, similarity: 0.7541370987892151
word: United_Sates, similarity: 0.7400724291801453
word: U.S., similarity: 0.7310773730278015
word: theUnited_States, similarity: 0.6404393911361694
word: America, similarity: 0.6178409457206726
word: UnitedStates, similarity: 0.6167312264442444
word: Europe, similarity: 0.6132988333702087
word: countries, similarity: 0.6044804453849792
word: Canada, similarity: 0.6019070148468018
"""
@@ -0,0 +1,40 @@
from knock60 import load_word_vectors


def main():
    """Main routine: vector arithmetic (Spain - Madrid + Athens)."""
    file_path = 'GoogleNews-vectors-negative300.bin.gz'
    word_vectors = load_word_vectors(file_path)

    word1 = 'Spain'
    word2 = 'Madrid'
    word3 = 'Athens'  # Athens is the capital of Greece

    try:
        vector = word_vectors[word1] - word_vectors[word2] + word_vectors[word3]
        similar_words = word_vectors.similar_by_vector(vector, topn=10)

        print(f"Top 10 words most similar to vec('{word1}') - vec('{word2}') + vec('{word3}'):")
        for similar_word, similarity in similar_words:
            print(f"word: {similar_word}, similarity: {similarity}")
    except KeyError as e:
        print(f"'{e.args[0]}' is not in the model.")


if __name__ == '__main__':
    main()


"""
Top 10 words most similar to vec('Spain') - vec('Madrid') + vec('Athens'):
word: Athens, similarity: 0.7528456449508667
word: Greece, similarity: 0.6685471534729004
word: Aristeidis_Grigoriadis, similarity: 0.5495778322219849
word: Ioannis_Drymonakos, similarity: 0.5361456871032715
word: Greeks, similarity: 0.5351786613464355
word: Ioannis_Christou, similarity: 0.5330225825309753
word: Hrysopiyi_Devetzi, similarity: 0.5088489651679993
word: Iraklion, similarity: 0.5059264898300171
word: Greek, similarity: 0.5040615797042847
word: Athens_Greece, similarity: 0.5034109950065613
"""
@@ -0,0 +1,21 @@
from gensim.models import KeyedVectors
from tqdm import tqdm


def main():
    model_path = 'GoogleNews-vectors-negative300.bin.gz'
    model = KeyedVectors.load_word2vec_format(model_path, binary=True)

    questions_file = 'questions-words.txt'
    output_file = 'questions-words-add.txt'

    with open(questions_file, 'r', encoding='utf-8') as f_in, open(output_file, 'w', encoding='utf-8') as f_out:
        for line in tqdm(f_in, desc='Processing'):
            if line.startswith(':'):
                # Category header, e.g. ": capital-common-countries"
                category = line.split()[1]
            else:
                # Analogy question: word[0] is to word[1] as word[2] is to word[3]
                words = line.split()
                word, cos = model.most_similar(positive=[words[1], words[2]], negative=[words[0]], topn=1)[0]
                f_out.write(f"{category} {' '.join(words)} {word} {cos}\n")


if __name__ == '__main__':
    main()
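Each most_similar call scans the full 3M-word vocabulary, so this loop takes a while. If speed matters, most_similar accepts a restrict_vocab argument that searches only the first N (frequency-sorted) entries, trading coverage for time. A sketch of a drop-in replacement for the call inside the loop, under that assumption:

# Sketch: same prediction restricted to the 300k most frequent words.
word, cos = model.most_similar(positive=[words[1], words[2]], negative=[words[0]],
                               topn=1, restrict_vocab=300000)[0]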
@@ -0,0 +1,33 @@
from tqdm import tqdm


def main():
    input_file = 'questions-words-add.txt'

    with open(input_file, 'r', encoding='utf-8') as f:
        sem_cnt = 0  # semantic analogy questions
        sem_cor = 0  # ... answered correctly
        syn_cnt = 0  # syntactic analogy questions
        syn_cor = 0  # ... answered correctly

        for line in tqdm(f, desc='Processing'):
            line = line.split()
            # Syntactic categories are prefixed with 'gram' (e.g. gram1-adjective-to-adverb)
            if not line[0].startswith('gram'):
                sem_cnt += 1
                if line[4] == line[5]:  # gold answer == predicted word
                    sem_cor += 1
            else:
                syn_cnt += 1
                if line[4] == line[5]:
                    syn_cor += 1

    print(f'Semantic analogy accuracy: {sem_cor/sem_cnt:.3f}')
    print(f'Syntactic analogy accuracy: {syn_cor/syn_cnt:.3f}')


if __name__ == '__main__':
    main()


"""
Semantic analogy accuracy: 0.731
Syntactic analogy accuracy: 0.740
"""
@@ -0,0 +1,41 @@
import numpy as np
import pandas as pd
from gensim.models import KeyedVectors
from tqdm import tqdm
from scipy.stats import spearmanr


def cos_sim(v1, v2):
    return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))


def calc_cos_sim(row, model):
    try:
        w1v = model[row['Word 1']]
        w2v = model[row['Word 2']]
        return cos_sim(w1v, w2v)
    except KeyError:
        # Either word is out of vocabulary
        return np.nan


def main():
    model_path = 'GoogleNews-vectors-negative300.bin.gz'
    model = KeyedVectors.load_word2vec_format(model_path, binary=True)

    csv_file = 'wordsim353/combined.csv'
    combined_df = pd.read_csv(csv_file)

    tqdm.pandas(desc='Calculating cosine similarity')
    combined_df['cos_sim'] = combined_df.progress_apply(calc_cos_sim, axis=1, model=model)

    # Drop pairs containing out-of-vocabulary words
    combined_df = combined_df.dropna(subset=['cos_sim'])

    human_scores = combined_df['Human (mean)'].values
    cos_sim_scores = combined_df['cos_sim'].values

    spearman_corr, _ = spearmanr(human_scores, cos_sim_scores)
    print(f'Spearman correlation: {spearman_corr:.3f}')


if __name__ == '__main__':
    main()


"""
Spearman correlation: 0.700
"""
@@ -0,0 +1,54 @@
import numpy as np
from gensim.models import KeyedVectors
from sklearn.cluster import KMeans


def main():
    # Load GoogleNews-vectors-negative300.bin.gz
    model = KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin.gz', binary=True)

    # Collect country names from the capital-common-countries, capital-world,
    # currency, and gram6-nationality-adjective categories
    countries = set()
    with open('questions-words-add.txt', 'r') as f:
        for line in f:
            line = line.split()
            if line[0] in ['capital-common-countries', 'capital-world']:
                countries.add(line[2])
            elif line[0] in ['currency', 'gram6-nationality-adjective']:
                countries.add(line[1])
    countries = list(countries)

    # Look up the word vector for each country
    countries_vec = [model[country] for country in countries]

    # k-means clustering with k = 5
    n_clusters = 5
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    kmeans.fit(countries_vec)

    # Print the members of each cluster
    for i in range(n_clusters):
        cluster = np.where(kmeans.labels_ == i)[0]
        print(f'Cluster {i}:')
        print(', '.join([countries[k] for k in cluster]))
        print()


if __name__ == '__main__':
    main()


"""
Cluster 0:
Denmark, Japan, Austria, Spain, Tuvalu, Malaysia, Belgium, Portugal, Morocco, England, Oman, Malta, Netherlands, Samoa, Sweden, Vietnam, Europe, Nepal, Iceland, Thailand, Italy, Ireland, Greenland, Indonesia, Australia, China, India, Philippines, USA, Germany, Fiji, Laos, Switzerland, Canada, Korea, Taiwan, Finland, Qatar, Bahrain, Liechtenstein, Cambodia, Bhutan, France, Norway, Bangladesh
Cluster 1:
Uganda, Namibia, Nigeria, Rwanda, Niger, Zimbabwe, Mozambique, Mali, Burundi, Algeria, Mauritania, Madagascar, Gabon, Senegal, Tunisia, Kenya, Gambia, Malawi, Liberia, Zambia, Angola, Guinea, Botswana, Ghana
Cluster 2:
Dominica, Colombia, Brazil, Cuba, Peru, Jamaica, Uruguay, Nicaragua, Suriname, Belize, Chile, Ecuador, Honduras, Argentina, Guyana, Venezuela, Mexico, Bahamas
Cluster 3:
Lebanon, Libya, Sudan, Syria, Israel, Pakistan, Iran, Egypt, Afghanistan, Jordan, Iraq, Somalia, Eritrea
Cluster 4:
Croatia, Albania, Latvia, Ukraine, Cyprus, Azerbaijan, Poland, Montenegro, Greece, Russia, Macedonia, Georgia, Kyrgyzstan, Estonia, Uzbekistan, Hungary, Turkmenistan, Turkey, Armenia, Belarus, Lithuania, Moldova, Bulgaria, Tajikistan, Slovenia, Serbia, Kazakhstan, Romania, Slovakia
"""
@@ -0,0 +1,35 @@
import numpy as np
from gensim.models import KeyedVectors
from scipy.cluster.hierarchy import dendrogram, linkage
from matplotlib import pyplot as plt


def main():
    # Load GoogleNews-vectors-negative300.bin.gz
    model = KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin.gz', binary=True)

    # Collect country names (same categories as in the k-means script)
    countries = set()
    with open('questions-words-add.txt', 'r') as f:
        for line in f:
            line = line.split()
            if line[0] in ['capital-common-countries', 'capital-world']:
                countries.add(line[2])
            elif line[0] in ['currency', 'gram6-nationality-adjective']:
                countries.add(line[1])
    countries = list(countries)

    # Look up the word vector for each country
    countries_vec = [model[country] for country in countries]

    # Hierarchical clustering with Ward's method
    Z = linkage(countries_vec, method='ward')

    # Visualize the dendrogram
    plt.figure(figsize=(20, 10))
    dendrogram(Z, labels=countries, leaf_rotation=90, leaf_font_size=8)
    plt.tight_layout()
    plt.savefig('ward_dendrogram.png')
    plt.show()


if __name__ == '__main__':
    main()
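The linkage matrix can also be cut into flat clusters for a direct comparison with the k-means result above; a sketch using scipy's fcluster, meant to run at the end of main():

# Sketch: cut the Ward dendrogram into 5 flat clusters (comparable to the k-means script).
from scipy.cluster.hierarchy import fcluster

labels = fcluster(Z, t=5, criterion='maxclust')  # cluster ids 1..5
for i in range(1, 6):
    members = [c for c, lab in zip(countries, labels) if lab == i]
    print(f'Cluster {i}:')
    print(', '.join(members))
    print()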