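"""Synthesis-time front end for the multi-speaker Tacotron model.

Builds the inference graph, restores a trained checkpoint, and turns text
(or pre-encoded token sequences) into waveforms. Attention alignments can
be plotted, trimmed, or manually overridden between synthesis passes.
"""
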
import io
import os
import re
import argparse
from glob import glob
from functools import partial

import librosa
import numpy as np
import tensorflow as tf
from tqdm import tqdm

from hparams import hparams
# get_most_recent_checkpoint is defined in this module (below), so only
# create_model is imported here to avoid shadowing it.
from models import create_model
from audio import save_audio, inv_spectrogram, inv_preemphasis, \
        inv_spectrogram_tensorflow
from utils import plot, PARAMS_NAME, load_json, load_hparams, \
        add_prefix, add_postfix, get_time, parallel_run, makedirs

from text.korean import tokenize
from text import text_to_sequence, sequence_to_text


class Synthesizer(object):
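    """Owns a TF session holding a restored Tacotron graph and runs inference."""
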
    def close(self):
        tf.reset_default_graph()
        self.sess.close()

    def load(self, checkpoint_path, num_speakers=2,
             checkpoint_step=None, model_name='tacotron'):
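        """Build the inference graph and restore weights from a checkpoint.

        `checkpoint_path` may be a specific checkpoint file or a training
        log directory; for a directory, the latest checkpoint (or the one
        at `checkpoint_step`) is resolved automatically.
        """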
        self.num_speakers = num_speakers

        if os.path.isdir(checkpoint_path):
            load_path = checkpoint_path
            checkpoint_path = get_most_recent_checkpoint(
                    checkpoint_path, checkpoint_step)
        else:
            load_path = os.path.dirname(checkpoint_path)

        print('Constructing model: %s' % model_name)

        inputs = tf.placeholder(tf.int32, [None, None], 'inputs')
        input_lengths = tf.placeholder(tf.int32, [None], 'input_lengths')

        batch_size = tf.shape(inputs)[0]
        speaker_id = tf.placeholder_with_default(
                tf.zeros([batch_size], dtype=tf.int32), [None], 'speaker_id')

        load_hparams(hparams, load_path)
        with tf.variable_scope('model') as scope:
            self.model = create_model(hparams)

            self.model.initialize(
                    inputs, input_lengths,
                    self.num_speakers, speaker_id)
            self.wav_output = \
                    inv_spectrogram_tensorflow(self.model.linear_outputs)

        print('Loading checkpoint: %s' % checkpoint_path)

        sess_config = tf.ConfigProto(
                allow_soft_placement=True,
                intra_op_parallelism_threads=1,
                inter_op_parallelism_threads=2)
        sess_config.gpu_options.allow_growth = True

        self.sess = tf.Session(config=sess_config)
        self.sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(self.sess, checkpoint_path)

    def synthesize(self,
                   texts=None, tokens=None,
                   base_path=None, paths=None, speaker_ids=None,
                   start_of_sentence=None, end_of_sentence=True,
                   pre_word_num=0, post_word_num=0,
                   pre_surplus_idx=0, post_surplus_idx=1,
                   use_short_concat=False,
                   manual_attention_mode=0,
                   base_alignment_path=None,
                   librosa_trim=False,
                   attention_trim=True):
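        """Synthesize audio for `texts` (or pre-encoded `tokens`).

        Returns one result per input: True when the audio was written to
        `paths`/`base_path`, otherwise the in-memory WAV bytes.
        """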
        # Possible inputs:
        # 1) texts=text
        # 2) texts=texts
        # 3) tokens=tokens, texts=texts  # use texts as a guide

        if type(texts) == str:
            texts = [texts]

        if texts is not None and tokens is None:
            sequences = [text_to_sequence(text) for text in texts]
        elif tokens is not None:
            sequences = tokens
        else:
            raise ValueError("Either `texts` or `tokens` must be given.")

        if paths is None:
            paths = [None] * len(sequences)
        if texts is None:
            texts = [None] * len(sequences)

        time_str = get_time()

        def plot_and_save_parallel(wavs, alignments, use_manual_attention):
            items = list(enumerate(zip(
                    wavs, alignments, paths, texts, sequences)))

            fn = partial(
                    plot_graph_and_save_audio,
                    base_path=base_path,
                    start_of_sentence=start_of_sentence,
                    end_of_sentence=end_of_sentence,
                    pre_word_num=pre_word_num, post_word_num=post_word_num,
                    pre_surplus_idx=pre_surplus_idx,
                    post_surplus_idx=post_surplus_idx,
                    use_short_concat=use_short_concat,
                    use_manual_attention=use_manual_attention,
                    librosa_trim=librosa_trim,
                    attention_trim=attention_trim,
                    time_str=time_str)
            return parallel_run(
                    fn, items,
                    desc="plot_graph_and_save_audio", parallel=False)

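        # Input lengths are recovered from the position of the end-of-sentence
        # token (assumed to be id 1, matching the `== 1` test below).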
        input_lengths = np.argmax(np.array(sequences) == 1, 1)

        fetches = [
                #self.wav_output,
                self.model.linear_outputs,
                self.model.alignments,
        ]

        feed_dict = {
                self.model.inputs: sequences,
                self.model.input_lengths: input_lengths,
        }

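        # Attention can be forced from precomputed alignments stored as
        # "<base_alignment_path>/<basename(base_path)>.<idx>.npy"; otherwise
        # a dummy tensor is fed and manual attention stays disabled.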
        if base_alignment_path is None:
            feed_dict.update({
                    self.model.manual_alignments: np.zeros([1, 1, 1]),
                    self.model.is_manual_attention: False,
            })
        else:
            manual_alignments = []
            alignment_path = os.path.join(
                    base_alignment_path,
                    os.path.basename(base_path))

            for idx in range(len(sequences)):
                numpy_path = "{}.{}.npy".format(alignment_path, idx)
                manual_alignments.append(np.load(numpy_path))

            alignments_T = np.transpose(manual_alignments, [0, 2, 1])
            feed_dict.update({
                    self.model.manual_alignments: alignments_T,
                    self.model.is_manual_attention: True,
            })

        if speaker_ids is not None:
            if type(speaker_ids) == dict:
                # Interpolate between speakers: `speaker_ids` maps
                # speaker index -> mixing weight.
                speaker_embed_table = self.sess.run(
                        self.model.speaker_embed_table)

                speaker_embed = np.sum(
                        [speaker_ids[speaker_id] * speaker_embed_table[speaker_id]
                         for speaker_id in speaker_ids], axis=0)
                # NOTE: the original source fed an empty np.tile() call here.
                # Feeding a table whose rows are all the mixed embedding is
                # one plausible completion: any speaker_id lookup then
                # returns the interpolated voice.
                feed_dict.update({
                        self.model.speaker_embed_table: np.tile(
                                speaker_embed, [self.num_speakers, 1])
                })
            else:
                feed_dict[self.model.speaker_id] = speaker_ids

        wavs, alignments = \
                self.sess.run(fetches, feed_dict=feed_dict)
        # The first pass is "manual" only when alignments were preloaded.
        results = plot_and_save_parallel(
                wavs, alignments, base_alignment_path is not None)

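        # Optionally run a second pass in which the decoder is forced to
        # follow a cleaned-up (one-hot, sharpened, or pruned) version of
        # the alignments produced above.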
        if manual_attention_mode > 0:
            # argmax one-hot
            if manual_attention_mode == 1:
                alignments_T = np.transpose(alignments, [0, 2, 1])  # [N, E, D]
                new_alignments = np.zeros_like(alignments_T)

                for idx in range(len(alignments)):
                    argmax = alignments[idx].argmax(1)
                    new_alignments[idx][(argmax, range(len(argmax)))] = 1
            # sharpening
            elif manual_attention_mode == 2:
                new_alignments = np.transpose(alignments, [0, 2, 1])  # [N, E, D]

                for idx in range(len(alignments)):
                    var = np.var(new_alignments[idx], 1)
                    mean_var = var[:input_lengths[idx]].mean()  # unused in the original

                    # Square to sharpen. The original called the nonexistent
                    # np.pow and overwrote the whole batch with one item;
                    # np.power on the idx-th alignment is the intended fix.
                    new_alignments[idx] = np.power(new_alignments[idx], 2)
            # pruning
            elif manual_attention_mode == 3:
                new_alignments = np.transpose(alignments, [0, 2, 1])  # [N, E, D]

                for idx in range(len(alignments)):
                    argmax = alignments[idx].argmax(1)
                    new_alignments[idx][(argmax, range(len(argmax)))] = 1

            feed_dict.update({
                    self.model.manual_alignments: new_alignments,
                    self.model.is_manual_attention: True,
            })

            new_wavs, new_alignments = \
                    self.sess.run(fetches, feed_dict=feed_dict)
            results = plot_and_save_parallel(
                    new_wavs, new_alignments, True)

        return results


def plot_graph_and_save_audio(args,
                              base_path=None,
                              start_of_sentence=None, end_of_sentence=None,
                              pre_word_num=0, post_word_num=0,
                              pre_surplus_idx=0, post_surplus_idx=1,
                              use_short_concat=False,
                              use_manual_attention=False,
                              save_alignment=False,
                              librosa_trim=False, attention_trim=False,
                              time_str=None):
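    """Plot the alignment, trim the predicted spectrogram, and emit audio.

    `args` is an `(idx, (spectrogram, alignment, path, text, sequence))`
    item as built by `plot_and_save_parallel`; despite the name, `wav`
    holds the linear spectrogram until `inv_spectrogram` is applied.
    """
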
    idx, (wav, alignment, path, text, sequence) = args

    if base_path:
        plot_path = "{}/{}.png".format(base_path, get_time())
    elif path:
        plot_path = path.rsplit('.', 1)[0] + ".png"
    else:
        plot_path = None

    #plot_path = add_prefix(plot_path, time_str)
    if use_manual_attention:
        plot_path = add_postfix(plot_path, "manual")

    if plot_path:
        plot.plot_alignment(alignment, plot_path, text=text)

    if use_short_concat:
        wav = short_concat(
                wav, alignment, text,
                start_of_sentence, end_of_sentence,
                pre_word_num, post_word_num,
                pre_surplus_idx, post_surplus_idx)

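    # Attention-based trimming: walk the decoder steps until attention has
    # dwelt on the final attended input symbol long enough (or moved past
    # it), then cut the spectrogram there to drop trailing garbage frames.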
    if attention_trim and end_of_sentence:
        end_idx_counter = 0
        attention_argmax = alignment.argmax(0)
        end_idx = min(len(sequence) - 1, max(attention_argmax))
        max_counter = min((attention_argmax == end_idx).sum(), 5)

        for jdx, attend_idx in enumerate(attention_argmax):
            if len(attention_argmax) > jdx + 1:
                if attend_idx == end_idx:
                    end_idx_counter += 1

                if attend_idx == end_idx and attention_argmax[jdx + 1] > end_idx:
                    break

                if end_idx_counter >= max_counter:
                    break
            else:
                break

        spec_end_idx = hparams.reduction_factor * jdx + 3
        wav = wav[:spec_end_idx]

    audio_out = inv_spectrogram(wav.T)

    if librosa_trim and end_of_sentence:
        yt, index = librosa.effects.trim(
                audio_out, frame_length=5120, hop_length=256, top_db=50)
        audio_out = audio_out[:index[-1]]

    if save_alignment:
        alignment_path = "{}/{}.npy".format(base_path, idx)
        np.save(alignment_path, alignment, allow_pickle=False)

    if path or base_path:
        if path:
            current_path = add_postfix(path, idx)
        elif base_path:
            current_path = plot_path.replace(".png", ".wav")

        save_audio(audio_out, current_path)
        return True
    else:
        io_out = io.BytesIO()
        save_audio(audio_out, io_out)
        return io_out.getvalue()


def get_most_recent_checkpoint(checkpoint_dir, checkpoint_step=None):
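    """Return the checkpoint path for the highest step found in
    `checkpoint_dir`, or for an explicit `checkpoint_step`."""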
    if checkpoint_step is None:
        checkpoint_paths = glob("{}/*.ckpt-*.data-*".format(checkpoint_dir))
        idxes = [int(os.path.basename(path).split('-')[1].split('.')[0])
                 for path in checkpoint_paths]

        max_idx = max(idxes)
    else:
        max_idx = checkpoint_step
    latest_checkpoint = os.path.join(
            checkpoint_dir, "model.ckpt-{}".format(max_idx))
    print(" [*] Found latest checkpoint: {}".format(latest_checkpoint))
    return latest_checkpoint


def short_concat(
        wav, alignment, text,
        start_of_sentence, end_of_sentence,
        pre_word_num, post_word_num,
        pre_surplus_idx, post_surplus_idx):
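    """Trim a chunk's spectrogram to word boundaries via its alignment.

    NOTE: the original body referenced `decomposed_text` without defining
    it, called an unimported `decompose_ko_text`, and fell off the end
    without returning. The definition and final `return wav` below are
    assumed repairs; `decompose_ko_text` is expected to live in
    `text.korean`.
    """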
    # np.array(list(decomposed_text))[attention_argmax]
    decomposed_text = decompose_ko_text(text)  # assumed definition; see docstring
    attention_argmax = alignment.argmax(0)

    if not start_of_sentence and pre_word_num > 0:
        surplus_decomposed_text = decompose_ko_text("".join(text.split()[0]))
        start_idx = len(surplus_decomposed_text) + 1

        for idx, attend_idx in enumerate(attention_argmax):
            if attend_idx == start_idx and attention_argmax[idx - 1] < start_idx:
                break

        wav_start_idx = hparams.reduction_factor * idx - 1 - pre_surplus_idx
    else:
        wav_start_idx = 0

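    # Locate the decoder step where attention moves past the last input
    # symbol to keep; that step marks the end of this chunk's audio.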
    if not end_of_sentence and post_word_num > 0:
        surplus_decomposed_text = decompose_ko_text("".join(text.split()[-1]))
        end_idx = len(decomposed_text.replace(surplus_decomposed_text, '')) - 1

        for idx, attend_idx in enumerate(attention_argmax):
            if attend_idx == end_idx and attention_argmax[idx + 1] > end_idx:
                break

        wav_end_idx = hparams.reduction_factor * idx + 1 + post_surplus_idx
    else:
        if True:  # attention-based split
            if end_of_sentence:
                end_idx = min(len(decomposed_text) - 1, max(attention_argmax))
            else:
                surplus_decomposed_text = decompose_ko_text("".join(text.split()[-1]))
                end_idx = len(decomposed_text.replace(surplus_decomposed_text, '')) - 1

            while True:
                if end_idx in attention_argmax:
                    break
                end_idx -= 1

            end_idx_counter = 0
            for idx, attend_idx in enumerate(attention_argmax):
                if len(attention_argmax) > idx + 1:
                    if attend_idx == end_idx:
                        end_idx_counter += 1

                    if attend_idx == end_idx and attention_argmax[idx + 1] > end_idx:
                        break

                    if end_idx_counter > 5:
                        break
                else:
                    break

            wav_end_idx = hparams.reduction_factor * idx + 1 + post_surplus_idx
        else:
            wav_end_idx = None

    wav = wav[wav_start_idx:wav_end_idx]

    if end_of_sentence:
        wav = np.lib.pad(wav, ((0, 20), (0, 0)), 'constant', constant_values=0)
    else:
        wav = np.lib.pad(wav, ((0, 10), (0, 0)), 'constant', constant_values=0)

    # The original function ended without returning; returning the trimmed,
    # padded spectrogram is the assumed intent.
    return wav


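# Example invocation (file and checkpoint paths are illustrative):
#   python synthesizer.py --load_path logs/my-model --text "..." \
#       --num_speakers 1 --speaker_id 0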
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--load_path', required=True)
    parser.add_argument('--sample_path', default="samples")
    parser.add_argument('--text', required=True)
    parser.add_argument('--num_speakers', default=1, type=int)
    parser.add_argument('--speaker_id', default=0, type=int)
    parser.add_argument('--checkpoint_step', default=None, type=int)
    config = parser.parse_args()

    makedirs(config.sample_path)

    synthesizer = Synthesizer()
    synthesizer.load(config.load_path, config.num_speakers, config.checkpoint_step)

    audio = synthesizer.synthesize(
            texts=[config.text],
            base_path=config.sample_path,
            speaker_ids=[config.speaker_id],
            attention_trim=False)[0]