multi-speaker-tacotron-tens.../models/tacotron.py

344 lines
16 KiB
Python
Raw Normal View History

2017-10-15 16:00:17 +02:00
# Code based on https://github.com/keithito/tacotron/blob/master/models/tacotron.py
import numpy as np
import tensorflow as tf
from tensorflow.contrib.seq2seq import BasicDecoder, BahdanauAttention, BahdanauMonotonicAttention
from tensorflow.contrib.rnn import GRUCell, MultiRNNCell, OutputProjectionWrapper, ResidualWrapper
from utils.infolog import log
from text.symbols import symbols
from .modules import *
from .helpers import TacoTestHelper, TacoTrainingHelper
from .rnn_wrappers import AttentionWrapper, DecoderPrenetWrapper, ConcatOutputAndAttentionWrapper
class Tacotron():
def __init__(self, hparams):
self._hparams = hparams
def initialize(
self, inputs, input_lengths, num_speakers, speaker_id,
mel_targets=None, linear_targets=None, loss_coeff=None,
rnn_decoder_test_mode=False, is_randomly_initialized=False,
):
is_training = linear_targets is not None
self.is_randomly_initialized = is_randomly_initialized
with tf.variable_scope('inference') as scope:
hp = self._hparams
batch_size = tf.shape(inputs)[0]
# Embeddings
char_embed_table = tf.get_variable(
'embedding', [len(symbols), hp.embedding_size], dtype=tf.float32,
initializer=tf.truncated_normal_initializer(stddev=0.5))
# [N, T_in, embedding_size]
char_embedded_inputs = \
tf.nn.embedding_lookup(char_embed_table, inputs)
self.num_speakers = num_speakers
if self.num_speakers > 1:
if hp.speaker_embedding_size != 1:
speaker_embed_table = tf.get_variable(
'speaker_embedding',
[self.num_speakers, hp.speaker_embedding_size], dtype=tf.float32,
initializer=tf.truncated_normal_initializer(stddev=0.5))
# [N, T_in, speaker_embedding_size]
speaker_embed = tf.nn.embedding_lookup(speaker_embed_table, speaker_id)
if hp.model_type == 'deepvoice':
if hp.speaker_embedding_size == 1:
before_highway = get_embed(
speaker_id, self.num_speakers,
hp.enc_prenet_sizes[-1], "before_highway")
encoder_rnn_init_state = get_embed(
speaker_id, self.num_speakers,
hp.enc_rnn_size * 2, "encoder_rnn_init_state")
attention_rnn_init_state = get_embed(
speaker_id, self.num_speakers,
hp.attention_state_size, "attention_rnn_init_state")
decoder_rnn_init_states = [get_embed(
speaker_id, self.num_speakers,
hp.dec_rnn_size, "decoder_rnn_init_states{}".format(idx + 1)) \
for idx in range(hp.dec_layer_num)]
else:
deep_dense = lambda x, dim: \
tf.layers.dense(x, dim, activation=tf.nn.softsign)
before_highway = deep_dense(
speaker_embed, hp.enc_prenet_sizes[-1])
encoder_rnn_init_state = deep_dense(
speaker_embed, hp.enc_rnn_size * 2)
attention_rnn_init_state = deep_dense(
speaker_embed, hp.attention_state_size)
decoder_rnn_init_states = [deep_dense(
speaker_embed, hp.dec_rnn_size) for _ in range(hp.dec_layer_num)]
speaker_embed = None # deepvoice does not use speaker_embed directly
elif hp.model_type == 'simple':
before_highway = None
encoder_rnn_init_state = None
attention_rnn_init_state = None
decoder_rnn_init_states = None
else:
raise Exception(" [!] Unkown multi-speaker model type: {}".format(hp.model_type))
else:
speaker_embed = None
before_highway = None
encoder_rnn_init_state = None
attention_rnn_init_state = None
decoder_rnn_init_states = None
##############
# Encoder
##############
# [N, T_in, enc_prenet_sizes[-1]]
prenet_outputs = prenet(char_embedded_inputs, is_training,
hp.enc_prenet_sizes, hp.dropout_prob,
scope='prenet')
encoder_outputs = cbhg(
prenet_outputs, input_lengths, is_training,
hp.enc_bank_size, hp.enc_bank_channel_size,
hp.enc_maxpool_width, hp.enc_highway_depth, hp.enc_rnn_size,
hp.enc_proj_sizes, hp.enc_proj_width,
scope="encoder_cbhg",
before_highway=before_highway,
encoder_rnn_init_state=encoder_rnn_init_state)
##############
# Attention
##############
# For manaul control of attention
self.is_manual_attention = tf.placeholder(
tf.bool, shape=(), name='is_manual_attention',
)
self.manual_alignments = tf.placeholder(
tf.float32, shape=[None, None, None], name="manual_alignments",
)
dec_prenet_outputs = DecoderPrenetWrapper(
GRUCell(hp.attention_state_size),
speaker_embed,
is_training, hp.dec_prenet_sizes, hp.dropout_prob)
if hp.attention_type == 'bah_mon':
attention_mechanism = BahdanauMonotonicAttention(
hp.attention_size, encoder_outputs)
elif hp.attention_type == 'bah_norm':
attention_mechanism = BahdanauAttention(
hp.attention_size, encoder_outputs, normalize=True)
elif hp.attention_type == 'luong_scaled':
attention_mechanism = LuongAttention(
hp.attention_size, encoder_outputs, scale=True)
elif hp.attention_type == 'luong':
attention_mechanism = LuongAttention(
hp.attention_size, encoder_outputs)
elif hp.attention_type == 'bah':
attention_mechanism = BahdanauAttention(
hp.attention_size, encoder_outputs)
elif hp.attention_type.startswith('ntm2'):
shift_width = int(hp.attention_type.split('-')[-1])
attention_mechanism = NTMAttention2(
hp.attention_size, encoder_outputs, shift_width=shift_width)
else:
raise Exception(" [!] Unkown attention type: {}".format(hp.attention_type))
attention_cell = AttentionWrapper(
dec_prenet_outputs,
attention_mechanism,
self.is_manual_attention,
self.manual_alignments,
initial_cell_state=attention_rnn_init_state,
alignment_history=True,
output_attention=False
)
# Concatenate attention context vector and RNN cell output into a 512D vector.
# [N, T_in, attention_size+attention_state_size]
concat_cell = ConcatOutputAndAttentionWrapper(
attention_cell, embed_to_concat=speaker_embed)
# Decoder (layers specified bottom to top):
cells = [OutputProjectionWrapper(concat_cell, hp.dec_rnn_size)]
for _ in range(hp.dec_layer_num):
cells.append(ResidualWrapper(GRUCell(hp.dec_rnn_size)))
# [N, T_in, 256]
decoder_cell = MultiRNNCell(cells, state_is_tuple=True)
# Project onto r mel spectrograms (predict r outputs at each RNN step):
output_cell = OutputProjectionWrapper(
decoder_cell, hp.num_mels * hp.reduction_factor)
decoder_init_state = output_cell.zero_state(
batch_size=batch_size, dtype=tf.float32)
if hp.model_type == "deepvoice":
# decoder_init_state[0] : AttentionWrapperState
# = cell_state + attention + time + alignments + alignment_history
# decoder_init_state[0][0] = attention_rnn_init_state (already applied)
decoder_init_state = list(decoder_init_state)
for idx, cell in enumerate(decoder_rnn_init_states):
shape1 = decoder_init_state[idx + 1].get_shape().as_list()
shape2 = cell.get_shape().as_list()
if shape1 != shape2:
raise Exception(" [!] Shape {} and {} should be equal". \
format(shape1, shape2))
decoder_init_state[idx + 1] = cell
decoder_init_state = tuple(decoder_init_state)
if is_training:
helper = TacoTrainingHelper(
inputs, mel_targets, hp.num_mels, hp.reduction_factor,
rnn_decoder_test_mode)
else:
helper = TacoTestHelper(
batch_size, hp.num_mels, hp.reduction_factor)
(decoder_outputs, _), final_decoder_state, _ = \
tf.contrib.seq2seq.dynamic_decode(
BasicDecoder(output_cell, helper, decoder_init_state),
maximum_iterations=hp.max_iters)
# [N, T_out, M]
mel_outputs = tf.reshape(
decoder_outputs, [batch_size, -1, hp.num_mels])
# Add post-processing CBHG:
# [N, T_out, 256]
#post_outputs = post_cbhg(mel_outputs, hp.num_mels, is_training)
post_outputs = cbhg(
mel_outputs, None, is_training,
hp.post_bank_size, hp.post_bank_channel_size,
hp.post_maxpool_width, hp.post_highway_depth, hp.post_rnn_size,
hp.post_proj_sizes, hp.post_proj_width,
scope='post_cbhg')
if speaker_embed is not None and hp.model_type == 'simple':
expanded_speaker_emb = tf.expand_dims(speaker_embed, [1])
tiled_speaker_embedding = tf.tile(
expanded_speaker_emb, [1, tf.shape(post_outputs)[1], 1])
# [N, T_out, 256 + alpha]
post_outputs = \
tf.concat([tiled_speaker_embedding, post_outputs], axis=-1)
linear_outputs = tf.layers.dense(post_outputs, hp.num_freq) # [N, T_out, F]
# Grab alignments from the final decoder state:
alignments = tf.transpose(
final_decoder_state[0].alignment_history.stack(), [1, 2, 0])
self.inputs = inputs
self.speaker_id = speaker_id
self.input_lengths = input_lengths
self.loss_coeff = loss_coeff
self.mel_outputs = mel_outputs
self.linear_outputs = linear_outputs
self.alignments = alignments
self.mel_targets = mel_targets
self.linear_targets = linear_targets
self.final_decoder_state = final_decoder_state
log('='*40)
log(' model_type: %s' % hp.model_type)
log('='*40)
log('Initialized Tacotron model. Dimensions: ')
log(' embedding: %d' % char_embedded_inputs.shape[-1])
if speaker_embed is not None:
log(' speaker embedding: %d' % speaker_embed.shape[-1])
else:
log(' speaker embedding: None')
log(' prenet out: %d' % prenet_outputs.shape[-1])
log(' encoder out: %d' % encoder_outputs.shape[-1])
log(' attention out: %d' % attention_cell.output_size)
log(' concat attn & out: %d' % concat_cell.output_size)
log(' decoder cell out: %d' % decoder_cell.output_size)
log(' decoder out (%d frames): %d' % (hp.reduction_factor, decoder_outputs.shape[-1]))
log(' decoder out (1 frame): %d' % mel_outputs.shape[-1])
log(' postnet out: %d' % post_outputs.shape[-1])
log(' linear out: %d' % linear_outputs.shape[-1])
def add_loss(self):
'''Adds loss to the model. Sets "loss" field. initialize must have been called.'''
with tf.variable_scope('loss') as scope:
hp = self._hparams
mel_loss = tf.abs(self.mel_targets - self.mel_outputs)
l1 = tf.abs(self.linear_targets - self.linear_outputs)
expanded_loss_coeff = tf.expand_dims(
tf.expand_dims(self.loss_coeff, [-1]), [-1])
if hp.prioritize_loss:
# Prioritize loss for frequencies.
upper_priority_freq = int(5000 / (hp.sample_rate * 0.5) * hp.num_freq)
lower_priority_freq = int(165 / (hp.sample_rate * 0.5) * hp.num_freq)
l1_priority= l1[:,:,lower_priority_freq:upper_priority_freq]
self.loss = tf.reduce_mean(mel_loss * expanded_loss_coeff) + \
0.5 * tf.reduce_mean(l1 * expanded_loss_coeff) + \
0.5 * tf.reduce_mean(l1_priority * expanded_loss_coeff)
self.linear_loss = tf.reduce_mean(
0.5 * (tf.reduce_mean(l1) + tf.reduce_mean(l1_priority)))
else:
self.loss = tf.reduce_mean(mel_loss * expanded_loss_coeff) + \
tf.reduce_mean(l1 * expanded_loss_coeff)
self.linear_loss = tf.reduce_mean(l1)
self.mel_loss = tf.reduce_mean(mel_loss)
self.loss_without_coeff = self.mel_loss + self.linear_loss
def add_optimizer(self, global_step):
'''Adds optimizer. Sets "gradients" and "optimize" fields. add_loss must have been called.
Args:
global_step: int32 scalar Tensor representing current global step in training
'''
with tf.variable_scope('optimizer') as scope:
hp = self._hparams
step = tf.cast(global_step + 1, dtype=tf.float32)
if hp.decay_learning_rate_mode == 0:
if self.is_randomly_initialized:
warmup_steps = 4000.0
else:
warmup_steps = 40000.0
self.learning_rate = hp.initial_learning_rate * warmup_steps**0.5 * \
tf.minimum(step * warmup_steps**-1.5, step**-0.5)
elif hp.decay_learning_rate_mode == 1:
self.learning_rate = hp.initial_learning_rate * \
tf.train.exponential_decay(1., step, 3000, 0.95)
optimizer = tf.train.AdamOptimizer(self.learning_rate, hp.adam_beta1, hp.adam_beta2)
gradients, variables = zip(*optimizer.compute_gradients(self.loss))
self.gradients = gradients
clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
# Add dependency on UPDATE_OPS; otherwise batchnorm won't work correctly. See:
# https://github.com/tensorflow/tensorflow/issues/1122
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
self.optimize = optimizer.apply_gradients(zip(clipped_gradients, variables),
global_step=global_step)
def get_dummy_feed_dict(self):
feed_dict = {
self.is_manual_attention: False,
self.manual_alignments: np.zeros([1, 1, 1]),
}
return feed_dict