import numpy as np
import tensorflow as tf
from tensorflow.contrib.rnn import RNNCell
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.layers import core as layers_core
from tensorflow.python.util import nest
from tensorflow.contrib.seq2seq.python.ops.attention_wrapper \
    import _bahdanau_score, _BaseAttentionMechanism, BahdanauAttention, \
           AttentionWrapperState, AttentionMechanism

from .modules import prenet

_zero_state_tensors = rnn_cell_impl._zero_state_tensors


class AttentionWrapper(RNNCell):
    """Wraps another `RNNCell` with attention."""

    def __init__(self,
                 cell,
                 attention_mechanism,
                 is_manual_attention,
                 manual_alignments,
                 attention_layer_size=None,
                 alignment_history=False,
                 cell_input_fn=None,
                 output_attention=True,
                 initial_cell_state=None,
                 name=None):
        """Construct the `AttentionWrapper`.

        Args:
          cell: An instance of `RNNCell`.
          attention_mechanism: A list of `AttentionMechanism` instances or a
            single instance.
          is_manual_attention: Scalar boolean `Tensor`. When `True`, alignments
            are read from `manual_alignments` instead of being computed by the
            attention mechanism.
          manual_alignments: A `Tensor` of shape
            `[batch_size, decoder_time, memory_time]` holding externally
            supplied alignments, indexed by the current decoder step.
          attention_layer_size: A list of Python integers or a single Python
            integer, the depth of the attention (output) layer(s). If None
            (default), use the context as attention at each time step.
            Otherwise, feed the context and cell output into the attention
            layer to generate attention at each time step. If
            attention_mechanism is a list, attention_layer_size must be a list
            of the same length.
          alignment_history: Python boolean, whether to store alignment history
            from all time steps in the final output state (currently stored as
            a time major `TensorArray` on which you must call `stack()`).
          cell_input_fn: (optional) A `callable`. The default is:
            `lambda inputs, attention: tf.concat([inputs, attention], -1)`.
          output_attention: Python bool. If `True` (default), the output at
            each time step is the attention value. This is the behavior of
            Luong-style attention mechanisms. If `False`, the output at each
            time step is the output of `cell`. This is the behavior of
            Bahdanau-style attention mechanisms. In both cases, the `attention`
            tensor is propagated to the next time step via the state and is
            used there. This flag only controls whether the attention mechanism
            is propagated up to the next cell in an RNN stack or to the top RNN
            output.
          initial_cell_state: The initial state value to use for the cell when
            the user calls `zero_state()`. Note that if this value is provided
            now, and the user uses a `batch_size` argument of `zero_state`
            which does not match the batch size of `initial_cell_state`, proper
            behavior is not guaranteed.
          name: Name to use when creating ops.

        Raises:
          TypeError: `attention_layer_size` is not None and
            (`attention_mechanism` is a list but `attention_layer_size` is not;
            or vice versa).
          ValueError: if `attention_layer_size` is not None,
            `attention_mechanism` is a list, and its length does not match that
            of `attention_layer_size`.
        """
        super(AttentionWrapper, self).__init__(name=name)

        self.is_manual_attention = is_manual_attention
        self.manual_alignments = manual_alignments

        if isinstance(attention_mechanism, (list, tuple)):
            self._is_multi = True
            attention_mechanisms = attention_mechanism
            for attention_mechanism in attention_mechanisms:
                if not isinstance(attention_mechanism, AttentionMechanism):
                    raise TypeError(
                        "attention_mechanism must contain only instances of "
                        "AttentionMechanism, saw type: %s"
                        % type(attention_mechanism).__name__)
        else:
            self._is_multi = False
            if not isinstance(attention_mechanism, AttentionMechanism):
                raise TypeError(
                    "attention_mechanism must be an AttentionMechanism or list of "
                    "multiple AttentionMechanism instances, saw type: %s"
                    % type(attention_mechanism).__name__)
            attention_mechanisms = (attention_mechanism,)

        if cell_input_fn is None:
            cell_input_fn = (
                lambda inputs, attention: tf.concat([inputs, attention], -1))
        else:
            if not callable(cell_input_fn):
                raise TypeError(
                    "cell_input_fn must be callable, saw type: %s"
                    % type(cell_input_fn).__name__)

        if attention_layer_size is not None:
            attention_layer_sizes = tuple(
                attention_layer_size
                if isinstance(attention_layer_size, (list, tuple))
                else (attention_layer_size,))
            if len(attention_layer_sizes) != len(attention_mechanisms):
                raise ValueError(
                    "If provided, attention_layer_size must contain exactly one "
                    "integer per attention_mechanism, saw: %d vs %d"
                    % (len(attention_layer_sizes), len(attention_mechanisms)))
            self._attention_layers = tuple(
                layers_core.Dense(
                    attention_layer_size, name="attention_layer", use_bias=False)
                for attention_layer_size in attention_layer_sizes)
            self._attention_layer_size = sum(attention_layer_sizes)
        else:
            self._attention_layers = None
            self._attention_layer_size = sum(
                attention_mechanism.values.get_shape()[-1].value
                for attention_mechanism in attention_mechanisms)

        self._cell = cell
        self._attention_mechanisms = attention_mechanisms
        self._cell_input_fn = cell_input_fn
        self._output_attention = output_attention
        self._alignment_history = alignment_history
        with tf.name_scope(name, "AttentionWrapperInit"):
            if initial_cell_state is None:
                self._initial_cell_state = None
            else:
                final_state_tensor = nest.flatten(initial_cell_state)[-1]
                state_batch_size = (
                    final_state_tensor.shape[0].value
                    or tf.shape(final_state_tensor)[0])
                error_message = (
                    "When constructing AttentionWrapper %s: " % self._base_name +
                    "Non-matching batch sizes between the memory "
                    "(encoder output) and initial_cell_state. Are you using "
                    "the BeamSearchDecoder? You may need to tile your initial state "
                    "via the tf.contrib.seq2seq.tile_batch function with argument "
                    "multiple=beam_width.")
                with tf.control_dependencies(
                        self._batch_size_checks(state_batch_size, error_message)):
                    self._initial_cell_state = nest.map_structure(
                        lambda s: tf.identity(s, name="check_initial_cell_state"),
                        initial_cell_state)

    def _batch_size_checks(self, batch_size, error_message):
        return [tf.assert_equal(batch_size,
                                attention_mechanism.batch_size,
                                message=error_message)
                for attention_mechanism in self._attention_mechanisms]

    def _item_or_tuple(self, seq):
        """Returns `seq` as a tuple or as its single element.

        Which form is returned depends on how the `AttentionMechanism`(s) were
        passed to the constructor.

        Args:
          seq: A non-empty sequence of items or generator.

        Returns:
          The values in the sequence as a tuple if the `AttentionMechanism`(s)
          were passed to the constructor as a sequence, otherwise the single
          element.
        """
        t = tuple(seq)
        if self._is_multi:
            return t
        else:
            return t[0]

    @property
    def output_size(self):
        if self._output_attention:
            return self._attention_layer_size
        else:
            return self._cell.output_size

    @property
    def state_size(self):
        return AttentionWrapperState(
            cell_state=self._cell.state_size,
            time=tf.TensorShape([]),
            attention=self._attention_layer_size,
            alignments=self._item_or_tuple(
                a.alignments_size for a in self._attention_mechanisms),
            alignment_history=self._item_or_tuple(
                () for _ in self._attention_mechanisms))  # sometimes a TensorArray

    def zero_state(self, batch_size, dtype):
        with tf.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
            if self._initial_cell_state is not None:
                cell_state = self._initial_cell_state
            else:
                cell_state = self._cell.zero_state(batch_size, dtype)
            error_message = (
                "When calling zero_state of AttentionWrapper %s: " % self._base_name +
                "Non-matching batch sizes between the memory "
                "(encoder output) and the requested batch size. Are you using "
                "the BeamSearchDecoder? If so, make sure your encoder output has "
                "been tiled to beam_width via tf.contrib.seq2seq.tile_batch, and "
                "the batch_size= argument passed to zero_state is "
                "batch_size * beam_width.")
            with tf.control_dependencies(
                    self._batch_size_checks(batch_size, error_message)):
                cell_state = nest.map_structure(
                    lambda s: tf.identity(s, name="checked_cell_state"),
                    cell_state)

            return AttentionWrapperState(
                cell_state=cell_state,
                time=tf.zeros([], dtype=tf.int32),
                attention=_zero_state_tensors(
                    self._attention_layer_size, batch_size, dtype),
                alignments=self._item_or_tuple(
                    attention_mechanism.initial_alignments(batch_size, dtype)
                    for attention_mechanism in self._attention_mechanisms),
                alignment_history=self._item_or_tuple(
                    tf.TensorArray(dtype=dtype, size=0, dynamic_size=True)
                    if self._alignment_history else ()
                    for _ in self._attention_mechanisms))

    def call(self, inputs, state):
        """Perform a step of attention-wrapped RNN.

        - Step 1: Mix the `inputs` and previous step's `attention` output via
          `cell_input_fn`.
        - Step 2: Call the wrapped `cell` with this input and its previous state.
        - Step 3: Score the cell's output with `attention_mechanism`.
        - Step 4: Calculate the alignments by passing the score through the
          `normalizer`.
        - Step 5: Calculate the context vector as the inner product between the
          alignments and the attention_mechanism's values (memory).
        - Step 6: Calculate the attention output by concatenating the cell
          output and context through the attention layer (a linear layer with
          `attention_layer_size` outputs).

        Args:
          inputs: (Possibly nested tuple of) Tensor, the input at this time step.
          state: An instance of `AttentionWrapperState` containing
            tensors from the previous time step.

        Returns:
          A tuple `(attention_or_cell_output, next_state)`, where:

          - `attention_or_cell_output` is the attention value or the cell
            output, depending on `output_attention`.
          - `next_state` is an instance of `AttentionWrapperState`
            containing the state calculated at this time step.

        Raises:
          TypeError: If `state` is not an instance of `AttentionWrapperState`.
        """
        if not isinstance(state, AttentionWrapperState):
            raise TypeError("Expected state to be instance of AttentionWrapperState. "
                            "Received type %s instead." % type(state))

        # Step 1: Calculate the true inputs to the cell based on the
        # previous attention value.
        cell_inputs = self._cell_input_fn(inputs, state.attention)
        cell_state = state.cell_state
        cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

        cell_batch_size = (
            cell_output.shape[0].value or tf.shape(cell_output)[0])
        error_message = (
            "When applying AttentionWrapper %s: " % self.name +
            "Non-matching batch sizes between the memory "
            "(encoder output) and the query (decoder output). Are you using "
            "the BeamSearchDecoder? You may need to tile your memory input via "
            "the tf.contrib.seq2seq.tile_batch function with argument "
            "multiple=beam_width.")
        with tf.control_dependencies(
                self._batch_size_checks(cell_batch_size, error_message)):
            cell_output = tf.identity(
                cell_output, name="checked_cell_output")

        if self._is_multi:
            previous_alignments = state.alignments
            previous_alignment_history = state.alignment_history
        else:
            previous_alignments = [state.alignments]
            previous_alignment_history = [state.alignment_history]

        all_alignments = []
        all_attentions = []
        all_histories = []

        for i, attention_mechanism in enumerate(self._attention_mechanisms):
            attention, alignments = _compute_attention(
                attention_mechanism, cell_output, previous_alignments[i],
                self._attention_layers[i] if self._attention_layers else None,
                self.is_manual_attention, self.manual_alignments, state.time)

            alignment_history = previous_alignment_history[i].write(
                state.time, alignments) if self._alignment_history else ()

            all_alignments.append(alignments)
            all_histories.append(alignment_history)
            all_attentions.append(attention)

        attention = tf.concat(all_attentions, 1)
        next_state = AttentionWrapperState(
            time=state.time + 1,
            cell_state=next_cell_state,
            attention=attention,
            alignments=self._item_or_tuple(all_alignments),
            alignment_history=self._item_or_tuple(all_histories))

        if self._output_attention:
            return attention, next_state
        else:
            return cell_output, next_state

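# A single decoder step with the wrapper looks like the following sketch
# (hypothetical names and shapes, illustrative only; not executed at import
# time):
#
#   wrapper = AttentionWrapper(decoder_cell, mechanism,
#                              is_manual_attention=tf.constant(False),
#                              manual_alignments=manual_alignments)
#   state = wrapper.zero_state(batch_size, tf.float32)
#   output, state = wrapper(inputs_t, state)  # attention or cell output,
#                                             # depending on output_attention

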
def _compute_attention(
        attention_mechanism, cell_output, previous_alignments,
        attention_layer, is_manual_attention, manual_alignments, time):
    """Computes the attention and alignments for a given attention_mechanism.

    When `is_manual_attention` is true, the alignments for the current decoder
    step are read from `manual_alignments` instead of the mechanism's scores.
    """
    computed_alignments = attention_mechanism(
        cell_output, previous_alignments=previous_alignments)

    alignments = tf.cond(
        is_manual_attention,
        lambda: manual_alignments[:, time, :],
        lambda: computed_alignments,
    )

    # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
    expanded_alignments = tf.expand_dims(alignments, 1)

    # Context is the inner product of alignments and values along the
    # memory time dimension.
    # alignments shape is
    #   [batch_size, 1, memory_time]
    # attention_mechanism.values shape is
    #   [batch_size, memory_time, memory_size]
    # the batched matmul is over memory_time, so the output shape is
    #   [batch_size, 1, memory_size].
    # we then squeeze out the singleton dim.
    context = tf.matmul(expanded_alignments, attention_mechanism.values)
    context = tf.squeeze(context, [1])

    if attention_layer is not None:
        attention = attention_layer(tf.concat([cell_output, context], 1))
    else:
        attention = context

    return attention, alignments


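# `manual_alignments` is indexed as [:, time, :], i.e. one alignment row per
# decoder step. A minimal sketch of building such a tensor to force a
# monotonic, one-hot alignment (hypothetical sizes, illustrative only):
#
#   manual = np.zeros([batch_size, n_decoder_steps, n_encoder_steps], np.float32)
#   for t in range(n_decoder_steps):
#       manual[:, t, min(t, n_encoder_steps - 1)] = 1.0

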
class DecoderPrenetWrapper(RNNCell):
    '''Runs RNN inputs through a prenet before sending them to the cell.'''

    def __init__(
            self, cell, embed_to_concat,
            is_training, prenet_sizes, dropout_prob):
        super(DecoderPrenetWrapper, self).__init__()
        self._is_training = is_training

        self._cell = cell
        self._embed_to_concat = embed_to_concat

        self.prenet_sizes = prenet_sizes
        self.dropout_prob = dropout_prob

    @property
    def state_size(self):
        return self._cell.state_size

    @property
    def output_size(self):
        return self._cell.output_size

    def call(self, inputs, state):
        prenet_out = prenet(
            inputs, self._is_training,
            self.prenet_sizes, self.dropout_prob, scope='decoder_prenet')

        if self._embed_to_concat is not None:
            concat_out = tf.concat(
                [prenet_out, self._embed_to_concat],
                axis=-1, name='speaker_concat')
            return self._cell(concat_out, state)
        else:
            return self._cell(prenet_out, state)

    def zero_state(self, batch_size, dtype):
        return self._cell.zero_state(batch_size, dtype)


class ConcatOutputAndAttentionWrapper(RNNCell):
    '''Concatenates RNN cell output with the attention context vector.

    This is expected to wrap a cell wrapped with an AttentionWrapper constructed
    with attention_layer_size=None and output_attention=False. Such a cell's
    state will include an "attention" field that is the context vector.
    '''

    def __init__(self, cell, embed_to_concat):
        super(ConcatOutputAndAttentionWrapper, self).__init__()
        self._cell = cell
        self._embed_to_concat = embed_to_concat

    @property
    def state_size(self):
        return self._cell.state_size

    @property
    def output_size(self):
        return self._cell.output_size + self._cell.state_size.attention

    def call(self, inputs, state):
        output, res_state = self._cell(inputs, state)

        if self._embed_to_concat is not None:
            tensors = [
                output, res_state.attention,
                self._embed_to_concat,
            ]
            return tf.concat(tensors, axis=-1), res_state
        else:
            return tf.concat([output, res_state.attention], axis=-1), res_state

    def zero_state(self, batch_size, dtype):
        return self._cell.zero_state(batch_size, dtype)
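

# The three wrappers are designed to compose, per the docstring on
# ConcatOutputAndAttentionWrapper above. A hedged end-to-end sketch
# (hypothetical names and sizes, not part of this module):
#
#   mechanism = BahdanauAttention(num_units=256, memory=encoder_outputs)
#   cell = DecoderPrenetWrapper(tf.contrib.rnn.GRUCell(256), speaker_embed,
#                               is_training, prenet_sizes=[256, 128],
#                               dropout_prob=0.5)
#   cell = AttentionWrapper(cell, mechanism,
#                           is_manual_attention, manual_alignments,
#                           attention_layer_size=None, output_attention=False)
#   cell = ConcatOutputAndAttentionWrapper(cell, speaker_embed)
#
# The final cell then emits [cell_output; attention_context] at each decoder
# step.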