multi-speaker-tacotron-tens.../synthesizer.py
2017-10-16 16:41:44 +09:00

390 lines
14 KiB
Python

import io
import os
import re
import librosa
import argparse
import numpy as np
from glob import glob
from tqdm import tqdm
import tensorflow as tf
from functools import partial
from hparams import hparams
from models import create_model, get_most_recent_checkpoint
from audio import save_audio, inv_spectrogram, inv_preemphasis, \
inv_spectrogram_tensorflow
from utils import plot, PARAMS_NAME, load_json, load_hparams, \
add_prefix, add_postfix, get_time, parallel_run, makedirs
from text.korean import tokenize
from text import text_to_sequence, sequence_to_text
class Synthesizer(object):
def close(self):
tf.reset_default_graph()
self.sess.close()
def load(self, checkpoint_path, num_speakers=2, checkpoint_step=None, model_name='tacotron'):
self.num_speakers = num_speakers
if os.path.isdir(checkpoint_path):
load_path = checkpoint_path
checkpoint_path = get_most_recent_checkpoint(checkpoint_path, checkpoint_step)
else:
load_path = os.path.dirname(checkpoint_path)
print('Constructing model: %s' % model_name)
inputs = tf.placeholder(tf.int32, [None, None], 'inputs')
input_lengths = tf.placeholder(tf.int32, [None], 'input_lengths')
batch_size = tf.shape(inputs)[0]
speaker_id = tf.placeholder_with_default(
tf.zeros([batch_size], dtype=tf.int32), [None], 'speaker_id')
load_hparams(hparams, load_path)
with tf.variable_scope('model') as scope:
self.model = create_model(hparams)
self.model.initialize(
inputs, input_lengths,
self.num_speakers, speaker_id)
self.wav_output = \
inv_spectrogram_tensorflow(self.model.linear_outputs)
print('Loading checkpoint: %s' % checkpoint_path)
sess_config = tf.ConfigProto(
allow_soft_placement=True,
intra_op_parallelism_threads=1,
inter_op_parallelism_threads=2)
sess_config.gpu_options.allow_growth = True
self.sess = tf.Session(config=sess_config)
self.sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(self.sess, checkpoint_path)
def synthesize(self,
texts=None, tokens=None,
base_path=None, paths=None, speaker_ids=None,
start_of_sentence=None, end_of_sentence=True,
pre_word_num=0, post_word_num=0,
pre_surplus_idx=0, post_surplus_idx=1,
use_short_concat=False,
manual_attention_mode=0,
base_alignment_path=None,
librosa_trim=False,
attention_trim=True):
# Possible inputs:
# 1) text=text
# 2) text=texts
# 3) tokens=tokens, texts=texts # use texts as guide
if type(texts) == str:
texts = [texts]
if texts is not None and tokens is None:
sequences = [text_to_sequence(text) for text in texts]
elif tokens is not None:
sequences = tokens
if paths is None:
paths = [None] * len(sequences)
if texts is None:
texts = [None] * len(sequences)
time_str = get_time()
def plot_and_save_parallel(
wavs, alignments, use_manual_attention):
items = list(enumerate(zip(
wavs, alignments, paths, texts, sequences)))
fn = partial(
plot_graph_and_save_audio,
base_path=base_path,
start_of_sentence=start_of_sentence, end_of_sentence=end_of_sentence,
pre_word_num=pre_word_num, post_word_num=post_word_num,
pre_surplus_idx=pre_surplus_idx, post_surplus_idx=post_surplus_idx,
use_short_concat=use_short_concat,
use_manual_attention=use_manual_attention,
librosa_trim=librosa_trim,
attention_trim=attention_trim,
time_str=time_str)
return parallel_run(fn, items,
desc="plot_graph_and_save_audio", parallel=False)
input_lengths = np.argmax(np.array(sequences) == 1, 1)
fetches = [
#self.wav_output,
self.model.linear_outputs,
self.model.alignments,
]
feed_dict = {
self.model.inputs: sequences,
self.model.input_lengths: input_lengths,
}
if base_alignment_path is None:
feed_dict.update({
self.model.manual_alignments: np.zeros([1, 1, 1]),
self.model.is_manual_attention: False,
})
else:
manual_alignments = []
alignment_path = os.path.join(
base_alignment_path,
os.path.basename(base_path))
for idx in range(len(sequences)):
numpy_path = "{}.{}.npy".format(alignment_path, idx)
manual_alignments.append(np.load(numpy_path))
alignments_T = np.transpose(manual_alignments, [0, 2, 1])
feed_dict.update({
self.model.manual_alignments: alignments_T,
self.model.is_manual_attention: True,
})
if speaker_ids is not None:
if type(speaker_ids) == dict:
speaker_embed_table = sess.run(
self.model.speaker_embed_table)
speaker_embed = [speaker_ids[speaker_id] * \
speaker_embed_table[speaker_id] for speaker_id in speaker_ids]
feed_dict.update({
self.model.speaker_embed_table: np.tile()
})
else:
feed_dict[self.model.speaker_id] = speaker_ids
wavs, alignments = \
self.sess.run(fetches, feed_dict=feed_dict)
results = plot_and_save_parallel(
wavs, alignments, True)
if manual_attention_mode > 0:
# argmax one hot
if manual_attention_mode == 1:
alignments_T = np.transpose(alignments, [0, 2, 1]) # [N, E, D]
new_alignments = np.zeros_like(alignments_T)
for idx in range(len(alignments)):
argmax = alignments[idx].argmax(1)
new_alignments[idx][(argmax, range(len(argmax)))] = 1
# sharpening
elif manual_attention_mode == 2:
new_alignments = np.transpose(alignments, [0, 2, 1]) # [N, E, D]
for idx in range(len(alignments)):
var = np.var(new_alignments[idx], 1)
mean_var = var[:input_lengths[idx]].mean()
new_alignments = np.pow(new_alignments[idx], 2)
# prunning
elif manual_attention_mode == 3:
new_alignments = np.transpose(alignments, [0, 2, 1]) # [N, E, D]
for idx in range(len(alignments)):
argmax = alignments[idx].argmax(1)
new_alignments[idx][(argmax, range(len(argmax)))] = 1
feed_dict.update({
self.model.manual_alignments: new_alignments,
self.model.is_manual_attention: True,
})
new_wavs, new_alignments = \
self.sess.run(fetches, feed_dict=feed_dict)
results = plot_and_save_parallel(
new_wavs, new_alignments, True)
return results
def plot_graph_and_save_audio(args,
base_path=None,
start_of_sentence=None, end_of_sentence=None,
pre_word_num=0, post_word_num=0,
pre_surplus_idx=0, post_surplus_idx=1,
use_short_concat=False,
use_manual_attention=False, save_alignment=False,
librosa_trim=False, attention_trim=False,
time_str=None):
idx, (wav, alignment, path, text, sequence) = args
if base_path:
plot_path = "{}/{}.png".format(base_path, get_time())
elif path:
plot_path = path.rsplit('.', 1)[0] + ".png"
else:
plot_path = None
#plot_path = add_prefix(plot_path, time_str)
if use_manual_attention:
plot_path = add_postfix(plot_path, "manual")
if plot_path:
plot.plot_alignment(alignment, plot_path, text=text)
if use_short_concat:
wav = short_concat(
wav, alignment, text,
start_of_sentence, end_of_sentence,
pre_word_num, post_word_num,
pre_surplus_idx, post_surplus_idx)
if attention_trim and end_of_sentence:
end_idx_counter = 0
attention_argmax = alignment.argmax(0)
end_idx = min(len(sequence) - 1, max(attention_argmax))
max_counter = min((attention_argmax == end_idx).sum(), 5)
for jdx, attend_idx in enumerate(attention_argmax):
if len(attention_argmax) > jdx + 1:
if attend_idx == end_idx:
end_idx_counter += 1
if attend_idx == end_idx and attention_argmax[jdx + 1] > end_idx:
break
if end_idx_counter >= max_counter:
break
else:
break
spec_end_idx = hparams.reduction_factor * jdx + 3
wav = wav[:spec_end_idx]
audio_out = inv_spectrogram(wav.T)
if librosa_trim and end_of_sentence:
yt, index = librosa.effects.trim(audio_out,
frame_length=5120, hop_length=256, top_db=50)
audio_out = audio_out[:index[-1]]
if save_alignment:
alignment_path = "{}/{}.npy".format(base_path, idx)
np.save(alignment_path, alignment, allow_pickle=False)
if path or base_path:
if path:
current_path = add_postfix(path, idx)
elif base_path:
current_path = plot_path.replace(".png", ".wav")
save_audio(audio_out, current_path)
return True
else:
io_out = io.BytesIO()
save_audio(audio_out, io_out)
result = io_out.getvalue()
return result
def get_most_recent_checkpoint(checkpoint_dir, checkpoint_step=None):
if checkpoint_step is None:
checkpoint_paths = [path for path in glob("{}/*.ckpt-*.data-*".format(checkpoint_dir))]
idxes = [int(os.path.basename(path).split('-')[1].split('.')[0]) for path in checkpoint_paths]
max_idx = max(idxes)
else:
max_idx = checkpoint_step
lastest_checkpoint = os.path.join(checkpoint_dir, "model.ckpt-{}".format(max_idx))
print(" [*] Found lastest checkpoint: {}".format(lastest_checkpoint))
return lastest_checkpoint
def short_concat(
wav, alignment, text,
start_of_sentence, end_of_sentence,
pre_word_num, post_word_num,
pre_surplus_idx, post_surplus_idx):
# np.array(list(decomposed_text))[attention_argmax]
attention_argmax = alignment.argmax(0)
if not start_of_sentence and pre_word_num > 0:
surplus_decomposed_text = decompose_ko_text("".join(text.split()[0]))
start_idx = len(surplus_decomposed_text) + 1
for idx, attend_idx in enumerate(attention_argmax):
if attend_idx == start_idx and attention_argmax[idx - 1] < start_idx:
break
wav_start_idx = hparams.reduction_factor * idx - 1 - pre_surplus_idx
else:
wav_start_idx = 0
if not end_of_sentence and post_word_num > 0:
surplus_decomposed_text = decompose_ko_text("".join(text.split()[-1]))
end_idx = len(decomposed_text.replace(surplus_decomposed_text, '')) - 1
for idx, attend_idx in enumerate(attention_argmax):
if attend_idx == end_idx and attention_argmax[idx + 1] > end_idx:
break
wav_end_idx = hparams.reduction_factor * idx + 1 + post_surplus_idx
else:
if True: # attention based split
if end_of_sentence:
end_idx = min(len(decomposed_text) - 1, max(attention_argmax))
else:
surplus_decomposed_text = decompose_ko_text("".join(text.split()[-1]))
end_idx = len(decomposed_text.replace(surplus_decomposed_text, '')) - 1
while True:
if end_idx in attention_argmax:
break
end_idx -= 1
end_idx_counter = 0
for idx, attend_idx in enumerate(attention_argmax):
if len(attention_argmax) > idx + 1:
if attend_idx == end_idx:
end_idx_counter += 1
if attend_idx == end_idx and attention_argmax[idx + 1] > end_idx:
break
if end_idx_counter > 5:
break
else:
break
wav_end_idx = hparams.reduction_factor * idx + 1 + post_surplus_idx
else:
wav_end_idx = None
wav = wav[wav_start_idx:wav_end_idx]
if end_of_sentence:
wav = np.lib.pad(wav, ((0, 20), (0, 0)), 'constant', constant_values=0)
else:
wav = np.lib.pad(wav, ((0, 10), (0, 0)), 'constant', constant_values=0)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--load_path', required=True)
parser.add_argument('--sample_path', default="samples")
parser.add_argument('--text', required=True)
parser.add_argument('--num_speakers', default=1, type=int)
parser.add_argument('--speaker_id', default=0, type=int)
parser.add_argument('--checkpoint_step', default=None, type=int)
config = parser.parse_args()
makedirs(config.sample_path)
synthesizer = Synthesizer()
synthesizer.load(config.load_path, config.num_speakers, config.checkpoint_step)
audio = synthesizer.synthesize(
texts=[config.text],
base_path=config.sample_path,
speaker_ids=[config.speaker_id],
attention_trim=False)[0]