|
| 1 | +import tensorflow as tf |
| 2 | + |
| 3 | +from tensorflow_addons.utils.types import TensorLike |
| 4 | +from typeguard import typechecked |
| 5 | +from typing import Tuple |
| 6 | + |
| 7 | + |
# original code taken from
# https://github.com/tensorflow/addons/blob/master/tensorflow_addons/text/crf.py
# (modified to our needs)
| 11 | + |
| 12 | + |
class CrfDecodeForwardRnnCell(tf.keras.layers.AbstractRNNCell):
    """RNN cell that performs one forward Viterbi step in a linear-chain CRF.

    At every time step the cell combines the previous step's scores with the
    transition potentials, keeps the best predecessor per tag (the
    backpointer), and emits softmax-normalised scores alongside those
    backpointers.
    """

    @typechecked
    def __init__(self, transition_params: TensorLike, **kwargs) -> None:
        """Initialize the CrfDecodeForwardRnnCell.

        Args:
            transition_params: A [num_tags, num_tags] matrix of binary
                potentials. It is expanded to [1, num_tags, num_tags] so
                that it broadcasts over the batch dimension inside `call`.
        """
        super().__init__(**kwargs)
        self._num_tags = transition_params.shape[0]
        self._transition_params = tf.expand_dims(transition_params, 0)

    @property
    def state_size(self) -> int:
        return self._num_tags

    @property
    def output_size(self) -> int:
        return self._num_tags

    def build(self, input_shape):
        super().build(input_shape)

    def call(
        self, inputs: TensorLike, state: TensorLike
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """Run one forward decoding step.

        Args:
            inputs: A [batch_size, num_tags] matrix of unary potentials.
            state: A [batch_size, num_tags] matrix containing the previous
                step's score values.

        Returns:
            output: A [batch_size, num_tags * 2] matrix of backpointers and
                scores, concatenated along the last axis.
            new_state: A [batch_size, num_tags] matrix of new score values.
        """
        prev_scores = tf.expand_dims(state[0], 2)
        transition_scores = prev_scores + self._transition_params
        new_state = inputs + tf.reduce_max(transition_scores, [1])

        # Backpointers are tag indices; cast them to float so they can be
        # concatenated with the (float) scores below.
        backpointers = tf.cast(tf.argmax(transition_scores, 1), tf.float32)

        # Softmax over the predecessor axis squashes the scores into [0, 1].
        scores = tf.reduce_max(tf.nn.softmax(transition_scores, axis=1), [1])

        # The Keras RNN layer keeps only the first returned value at every
        # time step, so backpointers and scores are packed into one output
        # tensor and split apart again by the caller.
        return tf.concat([backpointers, scores], axis=1), new_state
| 70 | + |
| 71 | + |
def crf_decode_forward(
    inputs: TensorLike,
    state: TensorLike,
    transition_params: TensorLike,
    sequence_lengths: TensorLike,
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Computes forward decoding in a linear-chain CRF.

    Args:
        inputs: A [batch_size, max_seq_len, num_tags] tensor of unary
            potentials, one time step per RNN iteration.
        state: A [batch_size, num_tags] matrix containing the initial
            score values.
        transition_params: A [num_tags, num_tags] matrix of binary potentials.
        sequence_lengths: A [batch_size] vector of true sequence lengths.

    Returns:
        output: A [batch_size, max_seq_len, num_tags * 2] tensor of
            backpointers and scores for every time step.
        new_state: A [batch_size, num_tags] matrix with the final score
            values.
    """
    lengths = tf.cast(sequence_lengths, dtype=tf.int32)
    # Mask out padded time steps so they do not influence decoding.
    mask = tf.sequence_mask(lengths, tf.shape(inputs)[1])
    forward_layer = tf.keras.layers.RNN(
        CrfDecodeForwardRnnCell(transition_params),
        return_sequences=True,
        return_state=True,
    )
    return forward_layer(inputs, state, mask=mask)
| 98 | + |
| 99 | + |
def crf_decode_backward(
    backpointers: TensorLike, scores: TensorLike, state: TensorLike
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Computes backward decoding in a linear-chain CRF.

    Args:
        backpointers: A [batch_size, max_seq_len, num_tags] int tensor of
            backpointers, already time-reversed so that iteration runs from
            the last valid step towards the first.
        scores: A [batch_size, max_seq_len, num_tags] float tensor of
            softmax-normalised scores, in the same order as `backpointers`.
        state: A [batch_size, 1] matrix with the tag index of the last
            valid step.

    Returns:
        new_tags: A [batch_size, max_seq_len, 1] tensor containing the
            decoded tag indices.
        new_scores: A [batch_size, max_seq_len, 1] tensor containing the
            score that belongs to each decoded tag.
    """
    # tf.scan iterates over the leading axis, so switch to time-major.
    backpointers = tf.transpose(backpointers, [1, 0, 2])
    scores = tf.transpose(scores, [1, 0, 2])

    def _scan_fn(
        _state: Tuple[tf.Tensor, tf.Tensor], _inputs: Tuple[tf.Tensor, tf.Tensor]
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        # One backward step: follow the backpointer of the current tag and
        # pick up the score recorded for that same tag at this time step.
        _backpointers, _scores = _inputs
        # Only the tag half of the carry is used for indexing; the score
        # half is carried along solely so the scan emits both outputs.
        tags = tf.cast(tf.squeeze(_state[0], axis=[1]), dtype=tf.int32)
        idxs = tf.stack([tf.range(tf.shape(_backpointers)[0]), tags], axis=1)
        new_tags = tf.expand_dims(tf.gather_nd(_backpointers, idxs), axis=-1)
        new_scores = tf.expand_dims(tf.gather_nd(_scores, idxs), axis=-1)
        return new_tags, new_scores

    # Bug fix: the previous implementation ran a second, independent scan
    # over `scores`, reusing `_scan_fn`, whose carry is its own output. For
    # that scan the carry was the previously gathered *score* (a float in
    # [0, 1]) cast to an integer index, which collapses to 0 after the first
    # step — so every later step returned tag 0's score instead of the
    # decoded tag's score. Scanning both tensors together keeps the tag
    # index as the shared carry and gathers the matching score each step.
    output_tags, output_scores = tf.scan(
        _scan_fn,
        (backpointers, scores),
        (state, tf.cast(state, dtype=tf.float32)),
    )

    return tf.transpose(output_tags, [1, 0, 2]), tf.transpose(output_scores, [1, 0, 2])
| 130 | + |
| 131 | + |
def crf_decode(
    potentials: TensorLike, transition_params: TensorLike, sequence_length: TensorLike
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Decode the highest scoring sequence of tags.

    Runs Viterbi decoding: a forward pass records, per time step and tag,
    the best predecessor (backpointer) together with a softmax-normalised
    score; a backward pass then follows the backpointers starting from the
    best tag of the last valid step.

    Args:
        potentials: A [batch_size, max_seq_len, num_tags] tensor of
            unary potentials.
        transition_params: A [num_tags, num_tags] matrix of
            binary potentials.
        sequence_length: A [batch_size] vector of true sequence lengths.

    Returns:
        decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
            Contains the highest scoring tag indices.
        decode_scores: A [batch_size, max_seq_len] matrix, containing the
            (softmax-normalised, per-step) score of `decode_tags`.
        best_score: A [batch_size] vector, containing the best unnormalised
            score of `decode_tags`.
    """
    sequence_length = tf.cast(sequence_length, dtype=tf.int32)

    # If max_seq_len is 1, we skip the algorithm and simply return the
    # argmax tag and the max activation.
    def _single_seq_fn():
        decode_tags = tf.cast(tf.argmax(potentials, axis=2), dtype=tf.int32)
        decode_scores = tf.reduce_max(tf.nn.softmax(potentials, axis=2), axis=2)
        best_score = tf.reshape(tf.reduce_max(potentials, axis=2), shape=[-1])
        return decode_tags, decode_scores, best_score

    def _multi_seq_fn():
        # Computes forward decoding. Get last score and backpointers.
        # The unaries of the first time step seed the recursion ...
        initial_state = tf.slice(potentials, [0, 0, 0], [-1, 1, -1])
        initial_state = tf.squeeze(initial_state, axis=[1])
        # ... and the remaining time steps are fed through the forward RNN.
        inputs = tf.slice(potentials, [0, 1, 0], [-1, -1, -1])

        # The forward pass consumes max_seq_len - 1 transitions; clamp at 0
        # so zero-length sequences do not go negative.
        sequence_length_less_one = tf.maximum(
            tf.constant(0, dtype=tf.int32), sequence_length - 1
        )

        output, last_score = crf_decode_forward(
            inputs, initial_state, transition_params, sequence_length_less_one
        )

        # output is a matrix of size [batch-size, max-seq-length, num-tags * 2]
        # split the matrix on axis 2 to get the backpointers and scores, which are
        # both of size [batch-size, max-seq-length, num-tags]
        backpointers, scores = tf.split(output, 2, axis=2)

        # The RNN cell emitted backpointers as float32; restore integer
        # tag indices.
        backpointers = tf.cast(backpointers, dtype=tf.int32)
        # Reverse in time (per true length) so the backward pass can run
        # front-to-back over the reversed sequence.
        backpointers = tf.reverse_sequence(
            backpointers, sequence_length_less_one, seq_axis=1
        )

        scores = tf.reverse_sequence(scores, sequence_length_less_one, seq_axis=1)

        # Start the backward pass from the best tag of the last valid step.
        initial_state = tf.cast(tf.argmax(last_score, axis=1), dtype=tf.int32)
        initial_state = tf.expand_dims(initial_state, axis=-1)

        # Softmax-normalised confidence of that final tag.
        initial_score = tf.reduce_max(tf.nn.softmax(last_score, axis=1), axis=[1])
        initial_score = tf.expand_dims(initial_score, axis=-1)

        decode_tags, decode_scores = crf_decode_backward(
            backpointers, scores, initial_state
        )

        # Prepend the final tag, then undo the earlier time reversal to
        # restore original sequence order.
        decode_tags = tf.squeeze(decode_tags, axis=[2])
        decode_tags = tf.concat([initial_state, decode_tags], axis=1)
        decode_tags = tf.reverse_sequence(decode_tags, sequence_length, seq_axis=1)

        decode_scores = tf.squeeze(decode_scores, axis=[2])
        decode_scores = tf.concat([initial_score, decode_scores], axis=1)
        decode_scores = tf.reverse_sequence(decode_scores, sequence_length, seq_axis=1)

        best_score = tf.reduce_max(last_score, axis=1)

        return decode_tags, decode_scores, best_score

    if potentials.shape[1] is not None:
        # shape is statically known, so we just execute
        # the appropriate code path
        if potentials.shape[1] == 1:
            return _single_seq_fn()

        return _multi_seq_fn()

    # Sequence length is only known at graph-execution time; branch there.
    return tf.cond(tf.equal(tf.shape(potentials)[1], 1), _single_seq_fn, _multi_seq_fn)
0 commit comments