-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathconvert-closure-embeddings-to-word-embeddings.py
80 lines (76 loc) · 4.17 KB
/
convert-closure-embeddings-to-word-embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import re
import time
import io
import sys
import argparse
from collections import defaultdict
import gzip
# Convert closure/cluster embeddings to per-word embeddings.
#
# Each line of the input holds the embedding of one cluster/closure of words
# that are translationally equivalent across one or more languages. An example
# cluster of two words is "en:dog_|_fr:chien". When such a cluster appears in
# the input file, the output file gets two separate lines -- one for "en:dog"
# and one for "fr:chien" -- both carrying the cluster's embedding.

# parse/validate arguments
argparser = argparse.ArgumentParser()
argparser.add_argument(
    "-i", "--input-filename", required=True,
    help=" An embeddings file (word2vec format) where the first column is a cluster id.")
argparser.add_argument(
    "-o", "--output-filename", required=True,
    help=" An embeddings file (word2vec format) where the first column corresponds to individual words.")
argparser.add_argument(
    "-w", "--word-clusters", required=True,
    help=" each line is formatted as 'cluster_id\tsurface_form\tfreq', indicating that the word 'surface_form' is one of the words in cluster 'cluster_id'. freq is obsolete and only used for compatibility with percy liang's cluster outputs")
args = argparser.parse_args()
cluster_to_words = defaultdict(list)
with gzip.open(args.word_clusters) if args.word_clusters.endswith('.gz') else open(args.word_clusters) as word_clusters_file:
# read word clusters
for line in word_clusters_file:
try:
line = line.decode('utf8')
except UnicodeDecodeError:
print 'WARNING: utf8 decoding error for the line:', line, '. Will skip this one.'
continue
splits = line.strip().split("\t")
assert(len(splits) == 3)
qualified_word, cluster_id = splits[1], splits[0]
cluster_to_words[cluster_id].append(qualified_word)
print '{} clusters read from {}'.format(len(cluster_to_words), args.word_clusters)
# stream
with gzip.open(args.output_filename, mode='w') if args.output_filename.endswith('.gz') else open(args.output_filename, mode='w') as output_file, gzip.open(args.input_filename, mode='r') if args.input_filename.endswith('.gz') else open(args.input_filename, mode='r') as input_file:
# initialize
unique_words = set()
embedding_dimensionality = -1
# the first line of the embeddings file is metadata (word2vec format). copy it as is.
output_file.write(input_file.readline())
# for each word cluster
for line in input_file:
try:
line = line.decode('utf8')
except UnicodeDecodeError:
print 'WARNING: utf8 decoding error for the line:', line, '. Will skip this one.'
continue
# read the cluster string and its embedding
line_splits = line.strip().split(' ')
# check embedding size
if embedding_dimensionality == -1:
embedding_dimensionality = len(line_splits) - 1
if len(line_splits) == 2:
# skip (another?) metadata line
continue
if embedding_dimensionality != len(line_splits) - 1:
print 'dimensionality problem: I thought the dimesnionality is {}, but this line has {} splits:\n{}'.format(embedding_dimensionality, len(line_splits), line_splits)
assert False
# merge embedding values back into a utf8-encoded string
embedding_string = u' '.join(line_splits[1:]).encode('utf8')
# split the cluster string into words
cluster = line_splits[0]
if cluster == '</s>': continue
if cluster not in cluster_to_words:
print 'FATAL: cluster {} not found in {}. Will die!'.format(cluster, args.word_clusters)
assert(cluster in cluster_to_words)
words = cluster_to_words[cluster]
for word in words:
if word in unique_words:
print u"WARNING: '{}' appears twice in input embeddings file. Will let go because the embeddings were apparently messed up. Please consider rebuilding your embeddings such that the cluster strings are not cut off. word2vec cuts off words of length > 1000 by default.".format(word)
out_line = '{} {}\n'.format(word.encode('utf8'), embedding_string)
# when -g is specified, don't write this line if the cluster is of size 1
output_file.write(out_line)
unique_words |= set(word)