initial commit
commit 0a6f6db0b1

56 changed files with 49395 additions and 0 deletions

132  .gitignore  vendored  Normal file
@@ -0,0 +1,132 @@
# Text
*.png
*.txt
*.json
*.csv

# Data
logs
*.npy
*.npz
*.tar
*.tar.gz

# Media
*.mp4
*.mp3
*.flac
*.wav
*.ts

.DS_Store

# Created by https://www.gitignore.io/api/python,vim

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
.venv/
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject


### Vim ###
# swap
[._]*.s[a-v][a-z]
[._]*.sw[a-p]
[._]s[a-v][a-z]
[._]sw[a-p]
# session
Session.vim
# temporary
.netrwhist
*~
# auto-generated tag files
tags

# End of https://www.gitignore.io/api/python,vim

3  DISCLAIMER  Normal file
@@ -0,0 +1,3 @@
This is not an official [DEVSISTERS](http://devsisters.com/) product, and the authors are not responsible for misuse or for any damage that you may cause. You agree that you use this software at your own risk.

이것은 [데브시스터즈](http://devsisters.com/)의 공식적인 제품이 아닙니다. [데브시스터즈](http://devsisters.com/)는 이 코드를 잘못 사용했을 시 발생한 문제나 이슈에 대한 책임을 지지 않으며 이 소프트웨어의 사용은 사용자 자신에게 전적으로 책임이 있습니다.

40  LICENSE  Normal file
@@ -0,0 +1,40 @@
Copyright (c) 2017 Devsisters

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.


Copyright (c) 2017 Keith Ito

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

168  README.md  Normal file
@@ -0,0 +1,168 @@
# Multi-Speaker Tacotron in TensorFlow

[[한국어 가이드](./README_ko.md)]

TensorFlow implementation of:

- [Deep Voice 2: Multi-Speaker Neural Text-to-Speech](https://arxiv.org/abs/1705.08947)
- [Listening while Speaking: Speech Chain by Deep Learning](https://arxiv.org/abs/1707.04879)
- [Tacotron: Towards End-to-End Speech Synthesis](https://arxiv.org/abs/1703.10135)

Audio samples (in Korean) can be found [here](http://carpedm20.github.io/tacotron/en.html).

![model](./assets/model.png)


## Prerequisites

- Python 3.6+
- [TensorFlow 1.3](https://www.tensorflow.org/install/)


## Usage

### 1. Install prerequisites

After installing [TensorFlow](https://www.tensorflow.org/install/), install the remaining prerequisites with:

    pip3 install -r requirements.txt

If you want to synthesize Korean speech right away, skip ahead to [2-3. Download pre-trained models](#2-3-download-pre-trained-models).


### 2-1. Generate custom datasets

The `datasets` directory should look like:

    datasets
    ├── jtbc
    │   ├── alignment.json
    │   └── audio
    │       ├── 1.mp3
    │       ├── 2.mp3
    │       ├── 3.mp3
    │       └── ...
    └── YOUR_DATASET
        ├── alignment.json
        └── audio
            ├── 1.mp3
            ├── 2.mp3
            ├── 3.mp3
            └── ...

and `YOUR_DATASET/alignment.json` should look like:

    {
        "./datasets/YOUR_DATASET/audio/001.mp3": "My name is Taehoon Kim.",
        "./datasets/YOUR_DATASET/audio/002.mp3": "The buses aren't the problem.",
        "./datasets/YOUR_DATASET/audio/003.mp3": "They have discovered a new particle.",
    }

Once the data is laid out as above, generate the preprocessed data with:

    python -m datasets.generate_data ./datasets/YOUR_DATASET/alignment.json
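
If your transcripts live in a plain text file rather than JSON, you can build `alignment.json` with a few lines of Python. This is only a sketch: it assumes a hypothetical tab-separated `transcripts.txt` (one `filename<TAB>sentence` per line) inside your dataset directory, which is not part of this repository.

    import json
    import os

    dataset_dir = "./datasets/YOUR_DATASET"   # adjust to your dataset name

    alignment = {}
    with open(os.path.join(dataset_dir, "transcripts.txt"), encoding="utf-8") as f:
        for line in f:
            filename, sentence = line.rstrip("\n").split("\t", 1)
            # Keys are audio paths relative to the repository root, as in the example above.
            alignment[os.path.join(dataset_dir, "audio", filename)] = sentence

    with open(os.path.join(dataset_dir, "alignment.json"), "w", encoding="utf-8") as f:
        json.dump(alignment, f, ensure_ascii=False, indent=2)
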
### 2-2. Generate Korean datasets

You can generate datasets for three public Korean figures:

1. [Sohn Suk-hee](https://en.wikipedia.org/wiki/Sohn_Suk-hee): anchor and president of JTBC
2. [Park Geun-hye](https://en.wikipedia.org/wiki/Park_Geun-hye): a former President of South Korea
3. [Moon Jae-in](https://en.wikipedia.org/wiki/Moon_Jae-in): the current President of South Korea

Each dataset can be generated with the following scripts:

    ./scripts/prepare_son.sh  # Sohn Suk-hee
    ./scripts/prepare_park.sh # Park Geun-hye
    ./scripts/prepare_moon.sh # Moon Jae-in


Each script executes the commands below (explained here with the `son` dataset).

0. To automate the alignment between audio and text, set up `GOOGLE_APPLICATION_CREDENTIALS` for the [Google Speech Recognition API](https://cloud.google.com/speech/). To get credentials, read [this](https://developers.google.com/identity/protocols/application-default-credentials).

       export GOOGLE_APPLICATION_CREDENTIALS="YOUR-GOOGLE.CREDENTIALS.json"

1. Download the speech (or video) and the text.

       python -m datasets.son.download

2. Split all audio files on silence.

       python -m audio.silence --audio_pattern "./datasets/son/audio/*.wav" --method=pydub

3. Using the [Google Speech Recognition API](https://cloud.google.com/speech/), predict a sentence for every segmented audio file. (This step is optional for `moon` and `park`, which already ship with `alignment.json`.)

       python -m recognition.google --audio_pattern "./datasets/son/audio/*.*.wav"

4. Compare the original text with the recognized text and save the `audio<->text` pairs to `./datasets/son/alignment.json`.

       python -m recognition.alignment --recognition_path "./datasets/son/recognition.json" --score_threshold=0.5

5. Finally, generate the numpy files that will be used in training.

       python3 -m datasets.synthesizer_data ./datasets/son/alignment.json

Because the automatic generation is extremely naive, the dataset is noisy. However, with enough data (20+ hours when training from random initialization, or 5+ hours when initializing from a pretrained model), you can expect acceptable audio synthesis quality.
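
To check whether a generated dataset clears those duration thresholds, you can sum the lengths of the audio files listed in `alignment.json`. The repository ships `audio/get_duration.py` for this; a minimal standalone sketch (assuming the `tinytag` package used elsewhere in this codebase is installed) might look like:

    import json

    from tinytag import TinyTag

    with open("./datasets/son/alignment.json", encoding="utf-8") as f:
        alignment = json.load(f)

    # TinyTag.get(...).duration is in seconds (None for unreadable files).
    total_sec = sum(TinyTag.get(path).duration or 0.0 for path in alignment)
    print("{:.1f} hours of aligned audio".format(total_sec / 3600))
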
### 2-3. Download pre-trained models

You can download pre-trained models and generate audio with them. Available models are:

1. A single-speaker model for [Sohn Suk-hee](https://en.wikipedia.org/wiki/Sohn_Suk-hee).

       python3 download.py son

2. A single-speaker model for [Park Geun-hye](https://en.wikipedia.org/wiki/Park_Geun-hye).

       python3 download.py park

After downloading a pre-trained model, you can generate voices as follows:

    python3 synthesizer.py --load_path logs/son-20171015 --text "이거 실화냐?"
    python3 synthesizer.py --load_path logs/park-20171015 --text "이거 실화냐?"

**WARNING: The two pre-trained models are made available for research purposes only.**


### 3. Train a model

To train a single-speaker model:

    python train.py --data_path=datasets/jtbc
    python train.py --data_path=datasets/park --initialize_path=PATH_TO_CHECKPOINT

To train a multi-speaker model:

    python train.py --data_path=datasets/jtbc,datasets/park

If you don't have a good, sufficiently large (10+ hours) dataset, it is better to pass `--initialize_path` so that a well-trained model provides the initial parameters.


### 4. Synthesize audio

After training, you can run the web demo to generate audio:

    python3 app.py --load_path logs/park-20171015 --num_speakers=1

or generate audio directly with:

    python3 synthesizer.py --load_path logs/park-20171015 --text "이거 실화냐?"
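
Under the hood, `app.py` serves a plain HTTP endpoint (`/generate`, see `app.py` later in this commit), so you can also fetch audio programmatically while the demo is running. A minimal sketch, assuming the server runs on its default port 5000 and that the `requests` package is available (it is not listed in this commit):

    import requests

    params = {"text": "이거 실화냐?", "speaker_id": 0}
    response = requests.get("http://localhost:5000/generate", params=params)
    response.raise_for_status()

    with open("out.wav", "wb") as f:
        f.write(response.content)   # the endpoint responds with a WAV file
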
## Disclaimer

This is not an official [DEVSISTERS](http://devsisters.com/) product. This project is not responsible for misuse or for any damage that you may cause. You agree that you use this software at your own risk.


## References

- [Keith Ito](https://github.com/keithito)'s [tacotron](https://github.com/keithito/tacotron)
- [DEVIEW 2017 presentation](https://deview.kr/2017/schedule/182) (Korean)


## Author

Taehoon Kim / [@carpedm20](http://carpedm20.github.io/)

169  README_ko.md  Normal file
@@ -0,0 +1,169 @@
# D.Voice: 오픈소스 딥러닝 음성 합성 엔진

[[English Guide](./README.md)]

D.Voice는 TensorFlow로 구현된 오픈소스 딥러닝 음성 합성 엔진입니다. 이 프로젝트는:

- [Deep Voice 2: Multi-Speaker Neural Text-to-Speech](https://arxiv.org/abs/1705.08947)
- [Listening while Speaking: Speech Chain by Deep Learning](https://arxiv.org/abs/1707.04879)
- [Tacotron: Towards End-to-End Speech Synthesis](https://arxiv.org/abs/1703.10135)

위 세 논문의 모델 구현을 포함하고 있습니다. 음성 데모는 [여기](http://carpedm20.github.io/tacotron/)서 들어보실 수 있습니다.

![model](./assets/model.png)


## Prerequisites

- Python 3.6+
- [Tensorflow 1.3](https://www.tensorflow.org/install/)


## 사용 방법

### 1. 필수 라이브러리 설치

[Tensorflow 1.3](https://www.tensorflow.org/install/)를 설치한 후, 아래 명령어로 필수 라이브러리를 설치합니다.

    pip3 install -r requirements.txt

바로 음성을 만들고 싶으면 [2-4. 미리 학습된 모델 다운받기](#2-4-미리-학습된-모델-다운받기)를 따라하시면 됩니다.


### 2-1. 학습할 데이터 준비하기

`datasets` 디렉토리는 다음과 같이 구성되어야 합니다:

    datasets
    ├── son
    │   ├── alignment.json
    │   └── audio
    │       ├── 1.mp3
    │       ├── 2.mp3
    │       ├── 3.mp3
    │       └── ...
    └── 아무개
        ├── alignment.json
        └── audio
            ├── 1.mp3
            ├── 2.mp3
            ├── 3.mp3
            └── ...

그리고 `아무개/alignment.json`은 아래와 같은 포맷의 `json` 형태로 준비해 주세요.

    {
        "./datasets/아무개/audio/001.mp3": "존경하는 국민 여러분",
        "./datasets/아무개/audio/002.mp3": "국회의장과 국회의원 여러분",
        "./datasets/아무개/audio/003.mp3": "저는 오늘",
    }

`datasets`와 `아무개/alignment.json`이 준비되면, 아래 명령어로 학습 데이터를 만드시면 됩니다:

    python3 -m datasets.synthesizer_data ./datasets/아무개/alignment.json


### 2-2. {손석희, 문재인, 박근혜} 데이터 만들기

만약 음성 데이터가 없으시다면, 3명의 한국인 음성 데이터를 만드실 수 있습니다:

1. [손석희](https://ko.wikipedia.org/wiki/%EC%86%90%EC%84%9D%ED%9D%AC)
2. [박근혜](https://ko.wikipedia.org/wiki/%EB%B0%95%EA%B7%BC%ED%98%9C)
3. [문재인](https://ko.wikipedia.org/wiki/%EB%AC%B8%EC%9E%AC%EC%9D%B8)

각각의 데이터는 아래 스크립트로 만들 수 있으며,

    ./scripts/prepare_son.sh  # 손석희
    ./scripts/prepare_park.sh # 박근혜
    ./scripts/prepare_moon.sh # 문재인


각 스크립트는 아래와 같은 명령어를 실행합니다. (son 기준으로 설명합니다)

0. 자동으로 `음성<->텍스트` 페어를 만들기 위해 [구글 음성 인식 API](https://cloud.google.com/speech/)를 사용하며, `GOOGLE_APPLICATION_CREDENTIALS`를 준비해야 합니다. `GOOGLE_APPLICATION_CREDENTIALS`를 얻기 위해서는 [여기](https://developers.google.com/identity/protocols/application-default-credentials)를 참고해 주세요.

       export GOOGLE_APPLICATION_CREDENTIALS="YOUR-GOOGLE.CREDENTIALS.json"

1. 음성(혹은 영상)과 텍스트 데이터를 다운로드 받습니다.

       python -m datasets.son.download

2. 음성을 정적을 기준으로 분리합니다.

       python -m audio.silence --audio_pattern "./datasets/son/audio/*.wav" --method=pydub

3. 작게 분리된 음성들을 [Google Speech Recognition API](https://cloud.google.com/speech/)를 사용해 대략적인 문장들을 예측합니다.

       python -m recognition.google --audio_pattern "./datasets/son/audio/*.*.wav"

4. 기존의 텍스트와 음성 인식으로 예측된 텍스트를 비교해 `음성<->텍스트` 쌍 정보를 `./datasets/son/alignment.json`에 저장합니다. (`moon`과 `park` 데이터셋은 `alignment.json`이 이미 있기 때문에 이 과정은 생략하셔도 됩니다.)

       python -m recognition.alignment --recognition_path "./datasets/son/recognition.json" --score_threshold=0.5

5. 마지막으로 학습에 사용될 numpy 파일들을 만듭니다.

       python3 -m datasets.synthesizer_data ./datasets/son/alignment.json


자동화 과정이 굉장히 간단하기 때문에, 데이터에 노이즈가 많이 존재합니다. 하지만 오디오와 텍스트가 충분히 많이 있다면 (처음부터 학습 시 20시간 이상, 미리 학습된 모델에서 학습 시 5시간 이상) 적당한 퀄리티의 음성 합성을 기대할 수 있습니다.


### 2-4. 미리 학습된 모델 다운받기

미리 학습된 모델들을 사용해 음성을 만들거나 모델을 학습시킬 수 있습니다. 아래 모델 중 하나를 다운로드 받으시고:

1. 단일 화자 모델 - [손석희](https://ko.wikipedia.org/wiki/%EC%86%90%EC%84%9D%ED%9D%AC)

       python3 download.py son

2. 단일 화자 모델 - [박근혜](https://ko.wikipedia.org/wiki/%EB%B0%95%EA%B7%BC%ED%98%9C)

       python3 download.py park

학습된 모델을 다운받으시고, 아래 명령어로 음성을 만들어 낼 수 있습니다:

    python3 synthesizer.py --load_path logs/son-20171015 --text "이거 실화냐?"
    python3 synthesizer.py --load_path logs/park-20171015 --text "이거 실화냐?"

**주의: 학습된 모델을 연구 이외의 목적으로 사용하는 것을 금지합니다.**


### 3. 모델 학습하기

단일 화자 모델을 학습하려면:

    python3 train.py --data_path=datasets/son
    python3 train.py --data_path=datasets/park --initialize_path logs/son-20171015

다중 화자 모델을 학습하려면:

    python3 train.py --data_path=datasets/son,datasets/park

학습 데이터가 좋지 않다면 `--initialize_path`로 이미 학습된 모델의 파라미터로 초기화해서 학습하시는 것이 좋습니다.


### 4. 음성 만들기

모델을 학습시킨 후 웹 데모를 통해 음성을 만들거나:

    python app.py --load_path logs/park-20171015 --num_speakers=1

아래 명령어로 음성을 만들 수 있습니다:

    python3 synthesizer.py --load_path logs/park-20171015 --text "이거 실화냐?"


## Disclaimer

이것은 [데브시스터즈](http://devsisters.com/)의 공식적인 제품이 아닙니다. [데브시스터즈](http://devsisters.com/)는 이 코드를 잘못 사용했을 시 발생한 문제나 이슈에 대한 책임을 지지 않으며 이 소프트웨어의 사용은 사용자 자신에게 전적으로 책임이 있습니다.


## References

- [Keith Ito](https://github.com/keithito)'s [tacotron](https://github.com/keithito/tacotron)
- [DEVIEW 2017 발표 자료](https://www.slideshare.net/carpedm20/deview-2017-80824162)


## Author

Taehoon Kim / [@carpedm20](http://carpedm20.github.io/)

133  app.py  Normal file
@@ -0,0 +1,133 @@
#!flask/bin/python
import os
import hashlib
import argparse

import pydub                    # needed by amplify() below; missing in the original commit
from pydub import AudioSegment  # needed by amplify() below; missing in the original commit
from flask_cors import CORS
from flask import Flask, request, render_template, jsonify, \
        send_from_directory, make_response, send_file

from hparams import hparams
from audio import load_audio
from synthesizer import Synthesizer
from utils import str2bool, prepare_dirs, makedirs, add_postfix

ROOT_PATH = "web"
AUDIO_DIR = "audio"
AUDIO_PATH = os.path.join(ROOT_PATH, AUDIO_DIR)

base_path = os.path.dirname(os.path.realpath(__file__))
static_path = os.path.join(base_path, 'web/static')

global_config = None
synthesizer = Synthesizer()
app = Flask(__name__, root_path=ROOT_PATH, static_url_path='')
CORS(app)


def match_target_amplitude(sound, target_dBFS):
    change_in_dBFS = target_dBFS - sound.dBFS
    return sound.apply_gain(change_in_dBFS)

def amplify(path, keep_silence=300):
    sound = AudioSegment.from_file(path)

    nonsilent_ranges = pydub.silence.detect_nonsilent(
        sound, silence_thresh=-50, min_silence_len=300)

    new_sound = None
    for idx, (start_i, end_i) in enumerate(nonsilent_ranges):
        if idx == len(nonsilent_ranges) - 1:
            end_i = None

        amplified_sound = \
            match_target_amplitude(sound[start_i:end_i], -20.0)

        if idx == 0:
            new_sound = amplified_sound
        else:
            new_sound = new_sound.append(amplified_sound)

        if idx < len(nonsilent_ranges) - 1:
            new_sound = new_sound.append(sound[end_i:nonsilent_ranges[idx+1][0]])

    return new_sound.export("out.mp3", format="mp3")

def generate_audio_response(text, speaker_id):
    global global_config

    model_name = os.path.basename(global_config.load_path)
    hashed_text = hashlib.md5(text.encode('utf-8')).hexdigest()

    relative_dir_path = os.path.join(AUDIO_DIR, model_name)
    relative_audio_path = os.path.join(
        relative_dir_path, "{}.{}.wav".format(hashed_text, speaker_id))
    real_path = os.path.join(ROOT_PATH, relative_audio_path)
    makedirs(os.path.dirname(real_path))

    if not os.path.exists(add_postfix(real_path, 0)):
        try:
            audio = synthesizer.synthesize(
                [text], paths=[real_path], speaker_ids=[speaker_id],
                attention_trim=True)[0]
        except:
            return jsonify(success=False), 400

    return send_file(
        add_postfix(relative_audio_path, 0),
        mimetype="audio/wav",
        as_attachment=True,
        attachment_filename=hashed_text + ".wav")

    # Unreachable in the original code; kept for reference.
    response = make_response(audio)
    response.headers['Content-Type'] = 'audio/wav'
    response.headers['Content-Disposition'] = 'attachment; filename=sound.wav'
    return response

@app.route('/')
def index():
    text = request.args.get('text') or "듣고 싶은 문장을 입력해 주세요."
    return render_template('index.html', text=text)

@app.route('/generate')
def view_method():
    text = request.args.get('text')
    speaker_id = int(request.args.get('speaker_id'))

    if text:
        return generate_audio_response(text, speaker_id)
    else:
        return {}

@app.route('/js/<path:path>')
def send_js(path):
    return send_from_directory(
        os.path.join(static_path, 'js'), path)

@app.route('/css/<path:path>')
def send_css(path):
    return send_from_directory(
        os.path.join(static_path, 'css'), path)

@app.route('/audio/<path:path>')
def send_audio(path):
    return send_from_directory(
        os.path.join(static_path, 'audio'), path)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--load_path', required=True)
    parser.add_argument('--checkpoint_step', default=None, type=int)
    parser.add_argument('--num_speakers', default=1, type=int)
    parser.add_argument('--port', default=5000, type=int)
    parser.add_argument('--debug', default=False, type=str2bool)
    config = parser.parse_args()

    if os.path.exists(config.load_path):
        prepare_dirs(config, hparams)

        global_config = config
        synthesizer.load(config.load_path, config.num_speakers, config.checkpoint_step)
    else:
        print(" [!] load_path not found: {}".format(config.load_path))

    app.run(host='0.0.0.0', port=config.port, debug=config.debug)

BIN  assets/model.png  Normal file
Binary file not shown. (Size after: 625 KiB)

168  audio/__init__.py  Normal file
@@ -0,0 +1,168 @@
# Code based on https://github.com/keithito/tacotron/blob/master/util/audio.py
import math
import numpy as np
import tensorflow as tf
from scipy import signal
from hparams import hparams

import librosa
import librosa.filters


def load_audio(path, pre_silence_length=0, post_silence_length=0):
    audio = librosa.core.load(path, sr=hparams.sample_rate)[0]
    if pre_silence_length > 0 or post_silence_length > 0:
        audio = np.concatenate([
            get_silence(pre_silence_length),
            audio,
            get_silence(post_silence_length),
        ])
    return audio

def save_audio(audio, path, sample_rate=None):
    audio *= 32767 / max(0.01, np.max(np.abs(audio)))
    librosa.output.write_wav(path, audio.astype(np.int16),
        hparams.sample_rate if sample_rate is None else sample_rate)

    print(" [*] Audio saved: {}".format(path))


def resample_audio(audio, target_sample_rate):
    return librosa.core.resample(
        audio, hparams.sample_rate, target_sample_rate)


def get_duration(audio):
    return librosa.core.get_duration(audio, sr=hparams.sample_rate)


def frames_to_hours(n_frames):
    return sum((n_frame for n_frame in n_frames)) * \
        hparams.frame_shift_ms / (3600 * 1000)


def get_silence(sec):
    return np.zeros(hparams.sample_rate * sec)


def spectrogram(y):
    D = _stft(_preemphasis(y))
    S = _amp_to_db(np.abs(D)) - hparams.ref_level_db
    return _normalize(S)


def inv_spectrogram(spectrogram):
    S = _db_to_amp(_denormalize(spectrogram) + hparams.ref_level_db)  # Convert back to linear
    return inv_preemphasis(_griffin_lim(S ** hparams.power))          # Reconstruct phase


def inv_spectrogram_tensorflow(spectrogram):
    S = _db_to_amp_tensorflow(_denormalize_tensorflow(spectrogram) + hparams.ref_level_db)
    return _griffin_lim_tensorflow(tf.pow(S, hparams.power))


def melspectrogram(y):
    D = _stft(_preemphasis(y))
    S = _amp_to_db(_linear_to_mel(np.abs(D)))
    return _normalize(S)


def inv_melspectrogram(melspectrogram):
    S = _mel_to_linear(_db_to_amp(_denormalize(melspectrogram)))  # Convert back to linear
    return inv_preemphasis(_griffin_lim(S ** hparams.power))      # Reconstruct phase


# Based on https://github.com/librosa/librosa/issues/434
def _griffin_lim(S):
    angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
    S_complex = np.abs(S).astype(np.complex)

    y = _istft(S_complex * angles)
    for i in range(hparams.griffin_lim_iters):
        angles = np.exp(1j * np.angle(_stft(y)))
        y = _istft(S_complex * angles)
    return y


def _griffin_lim_tensorflow(S):
    with tf.variable_scope('griffinlim'):
        S = tf.expand_dims(S, 0)
        S_complex = tf.identity(tf.cast(S, dtype=tf.complex64))
        y = _istft_tensorflow(S_complex)
        for i in range(hparams.griffin_lim_iters):
            est = _stft_tensorflow(y)
            angles = est / tf.cast(tf.maximum(1e-8, tf.abs(est)), tf.complex64)
            y = _istft_tensorflow(S_complex * angles)
        return tf.squeeze(y, 0)


def _stft(y):
    n_fft, hop_length, win_length = _stft_parameters()
    return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)


def _istft(y):
    _, hop_length, win_length = _stft_parameters()
    return librosa.istft(y, hop_length=hop_length, win_length=win_length)


def _stft_tensorflow(signals):
    n_fft, hop_length, win_length = _stft_parameters()
    return tf.contrib.signal.stft(signals, win_length, hop_length, n_fft, pad_end=False)


def _istft_tensorflow(stfts):
    n_fft, hop_length, win_length = _stft_parameters()
    return tf.contrib.signal.inverse_stft(stfts, win_length, hop_length, n_fft)

def _stft_parameters():
    n_fft = (hparams.num_freq - 1) * 2
    hop_length = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
    win_length = int(hparams.frame_length_ms / 1000 * hparams.sample_rate)
    return n_fft, hop_length, win_length


# Conversions:

_mel_basis = None
_inv_mel_basis = None

def _linear_to_mel(spectrogram):
    global _mel_basis
    if _mel_basis is None:
        _mel_basis = _build_mel_basis()
    return np.dot(_mel_basis, spectrogram)

def _mel_to_linear(mel_spectrogram):
    global _inv_mel_basis
    if _inv_mel_basis is None:
        _inv_mel_basis = np.linalg.pinv(_build_mel_basis())
    return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))

def _build_mel_basis():
    n_fft = (hparams.num_freq - 1) * 2
    return librosa.filters.mel(hparams.sample_rate, n_fft, n_mels=hparams.num_mels)

def _amp_to_db(x):
    return 20 * np.log10(np.maximum(1e-5, x))

def _db_to_amp(x):
    return np.power(10.0, x * 0.05)

def _db_to_amp_tensorflow(x):
    return tf.pow(tf.ones(tf.shape(x)) * 10.0, x * 0.05)

def _preemphasis(x):
    return signal.lfilter([1, -hparams.preemphasis], [1], x)

def inv_preemphasis(x):
    return signal.lfilter([1], [1, -hparams.preemphasis], x)

def _normalize(S):
    return np.clip((S - hparams.min_level_db) / -hparams.min_level_db, 0, 1)

def _denormalize(S):
    return (np.clip(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db

def _denormalize_tensorflow(S):
    return (tf.clip_by_value(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db

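Taken together, these functions form a small round-trip API: load a waveform, turn it into normalized dB-scale (mel-)spectrograms for training, and reconstruct audio from a predicted linear spectrogram with Griffin-Lim. A minimal usage sketch (the path is illustrative, and `hparams` must already define `sample_rate`, `num_freq`, `num_mels`, and the other fields referenced above):

    from audio import load_audio, save_audio, spectrogram, melspectrogram, inv_spectrogram

    wav = load_audio("./datasets/son/audio/example.wav")   # float waveform at hparams.sample_rate

    linear = spectrogram(wav)       # normalized linear spectrogram, shape (num_freq, frames)
    mel = melspectrogram(wav)       # normalized mel spectrogram, shape (num_mels, frames)

    # Reconstruct a waveform from the linear spectrogram via Griffin-Lim and save it as 16-bit PCM.
    reconstructed = inv_spectrogram(linear)
    save_audio(reconstructed, "reconstructed.wav")
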
71  audio/get_duration.py  Normal file
@@ -0,0 +1,71 @@
import os
import datetime
from glob import glob
from tqdm import tqdm
from tinytag import TinyTag
from collections import defaultdict
from multiprocessing.dummy import Pool

from utils import load_json

def second_to_hour(sec):
    return str(datetime.timedelta(seconds=int(sec)))

def get_duration(path):
    filename = os.path.basename(path)
    candidates = filename.split('.')[0].split('_')
    dataset = candidates[0]

    if not os.path.exists(path):
        print(" [!] {} not found".format(path))
        return dataset, 0

    if True: # tinytag
        tag = TinyTag.get(path)
        duration = tag.duration
    else: # librosa
        y, sr = librosa.load(path)
        duration = librosa.get_duration(y=y, sr=sr)

    return dataset, duration

def get_durations(paths, print_detail=True):
    duration_all = 0
    duration_book = defaultdict(list)

    pool = Pool()
    iterator = pool.imap_unordered(get_duration, paths)
    for dataset, duration in tqdm(iterator, total=len(paths)):
        duration_all += duration
        duration_book[dataset].append(duration)

    total_count = 0
    for book, duration in duration_book.items():
        if book:
            time = second_to_hour(sum(duration))
            file_count = len(duration)
            total_count += file_count

            if print_detail:
                print(" [*] Duration of {}: {} (file #: {})". \
                        format(book, time, file_count))

    print(" [*] Total Duration : {} (file #: {})". \
            format(second_to_hour(duration_all), total_count))
    print()
    return duration_all


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--audio-pattern', default=None) # datasets/krbook/audio/*.wav
    parser.add_argument('--data-path', default=None) # datasets/jtbc/alignment.json
    config, unparsed = parser.parse_known_args()

    if config.audio_pattern is not None:
        # The original called an undefined get_paths_by_pattern(config.data_dir);
        # globbing the given pattern is what the surrounding code expects.
        duration = get_durations(glob(config.audio_pattern))
    elif config.data_path is not None:
        paths = load_json(config.data_path).keys()
        duration = get_durations(paths)

520  audio/google_speech.py  Normal file
@@ -0,0 +1,520 @@
import io
import os
import sys
import json
import string
import argparse
import operator
import numpy as np
from glob import glob
from tqdm import tqdm
from nltk import ngrams
from difflib import SequenceMatcher
from collections import defaultdict

from google.cloud import speech
from google.cloud.speech import enums
from google.cloud.speech import types

from utils import parallel_run
from text import text_to_sequence

####################################################
# When one or two audio is missed in the middle
####################################################

def get_continuous_audio_paths(paths, debug=False):
    audio_ids = get_audio_ids_from_paths(paths)
    min_id, max_id = min(audio_ids), max(audio_ids)

    if int(max_id) - int(min_id) + 1 != len(audio_ids):
        base_path = paths[0].replace(min_id, "{:0" + str(len(max_id)) + "d}")
        new_paths = [
            base_path.format(audio_id) \
                for audio_id in range(int(min_id), int(max_id) + 1)]

        if debug: print("Missing audio : {} -> {}".format(paths, new_paths))
        return new_paths
    else:
        return paths

def get_argmax_key(info, with_value=False):
    max_key = max(info.keys(), key=(lambda k: info[k]))

    if with_value:
        return max_key, info[max_key]
    else:
        return max_key

def similarity(text_a, text_b):
    text_a = "".join(remove_puncuations(text_a.strip()).split())
    text_b = "".join(remove_puncuations(text_b.strip()).split())

    score = SequenceMatcher(None, text_a, text_b).ratio()
    #score = 1 / (distance(decompose_ko_text(text_a), decompose_ko_text(text_b)) + 1e-5)
    #score = SequenceMatcher(None,
    #        decompose_ko_text(text_a), decompose_ko_text(text_b)).ratio()

    if len(text_a) < len(text_b):
        return -1 + score
    else:
        return score

def get_key_value_sorted(data):
    keys = list(data.keys())
    keys.sort()
    values = [data[key] for key in keys]
    return keys, values

def replace_pred_with_book(
        path, book_path=None, threshold=0.9, max_candidate_num=5,
        min_post_char_check=2, max_post_char_check=7, max_n=5,
        max_allow_missing_when_matching=4, debug=False):

    #######################################
    # find text book from pred
    #######################################

    if book_path is None:
        book_path = path.replace("speech", "text").replace("json", "txt")

    data = json.loads(open(path).read())

    keys, preds = get_key_value_sorted(data)

    book_words = [word for word in open(book_path).read().split() if word != "=="]
    book_texts = [text.replace('\n', '') for text in open(book_path).readlines()]

    loc = 0
    prev_key = None
    force_stop = False
    prev_end_loc = -1
    prev_sentence_ended = True

    prev_empty_skip = False
    prev_not_found_skip = False

    black_lists = ["160.{:04d}".format(audio_id) for audio_id in range(20, 36)]

    new_preds = {}
    for key, pred in zip(keys, preds):
        if debug: print(key, pred)

        if pred == "" or key in black_lists:
            prev_empty_skip = True
            continue

        width, counter = 1, 0
        sim_dict, loc_dict = {}, {}

        while True:
            words = book_words[loc:loc + width]

            if len(words) == 0:
                print("Force stop. Left {}, Del {} {}". \
                        format(len(preds) - len(new_preds), new_preds[prev_key], prev_key))
                new_preds.pop(prev_key, None)
                force_stop = True
                break

            candidate_candidates = {}

            for _pred in list(set([pred, koreanize_numbers(pred)])):
                max_skip = 0 if has_number(_pred[0]) or \
                        _pred[0] in """"'“”’‘’""" else len(words)

                end_sims = []
                for idx in range(min(max_skip, 10)):
                    text = " ".join(words[idx:])

                    ################################################
                    # Score of trailing sentence is also important
                    ################################################

                    for jdx in range(min_post_char_check,
                                     max_post_char_check):
                        sim = similarity(
                            "".join(_pred.split())[-jdx:],
                            "".join(text.split())[-jdx:])
                        end_sims.append(sim)

                    candidate_candidates[text] = similarity(_pred, text)

            candidate, sim = get_argmax_key(
                candidate_candidates, with_value=True)

            if sim > threshold or max(end_sims + [-1]) > threshold - 0.2 or \
                    len(sim_dict) > 0:
                sim_dict[candidate] = sim
                loc_dict[candidate] = loc + width

                if len(sim_dict) > 0:
                    counter += 1

                if counter > max_candidate_num:
                    break

            width += 1

            if width - len(_pred.split()) > 5:
                break

        if force_stop:
            break

        if len(sim_dict) != 0:
            #############################################################
            # Check missing words between prev pred and current pred
            #############################################################

            if prev_key is not None:
                cur_idx = int(key.rsplit('.', 2)[-2])
                prev_idx = int(prev_key.rsplit('.', 2)[-2])

                if cur_idx - prev_idx > 10:
                    force_stop = True
                    break

            # word alinged based on prediction but may contain missing words
            # because google speech recognition sometimes skip one or two word
            # ex. ('오누이는 서로 자기가 할 일을 정했다.', '서로 자기가 할 일을 정했다.')
            original_candidate = new_candidate = get_argmax_key(sim_dict)

            word_to_find = original_candidate.split()[0]

            if not prev_empty_skip:
                search_idx = book_words[prev_end_loc:].index(word_to_find) \
                        if word_to_find in book_words[prev_end_loc:] else -1

                if 0 < search_idx < 4 and not prev_sentence_ended:
                    words_to_check = book_words[prev_end_loc:prev_end_loc + search_idx]

                    if ends_with_punctuation(words_to_check[0]) == True:
                        tmp = " ".join([new_preds[prev_key]] + words_to_check[:1])
                        if debug: print(prev_key, tmp, new_preds[prev_key])
                        new_preds[prev_key] = tmp

                        prev_end_loc += 1
                        prev_sentence_ended = True

                search_idx = book_words[prev_end_loc:].index(word_to_find) \
                        if word_to_find in book_words[prev_end_loc:] else -1

                if 0 < search_idx < 4 and prev_sentence_ended:
                    words_to_check = book_words[prev_end_loc:prev_end_loc + search_idx]

                    if not any(ends_with_punctuation(word) for word in words_to_check):
                        new_candidate = " ".join(words_to_check + [original_candidate])
                        if debug: print(key, new_candidate, original_candidate)

            new_preds[key] = new_candidate
            prev_sentence_ended = ends_with_punctuation(new_candidate)

            loc = loc_dict[original_candidate]
            prev_key = key
            prev_not_found_skip = False
        else:
            loc += len(_pred.split()) - 1
            prev_sentence_ended = True
            prev_not_found_skip = True

        prev_end_loc = loc
        prev_empty_skip = False

        if debug:
            print("=", pred)
            print("=", new_preds[key], loc)

    if force_stop:
        print(" [!] Force stop: {}".format(path))

    align_diff = loc - len(book_words)

    if abs(align_diff) > 10:
        print(" => Align result of {}: {} - {} = {}".format(path, loc, len(book_words), align_diff))

    #######################################
    # find exact match of n-gram of pred
    #######################################

    finished_ids = []

    keys, preds = get_key_value_sorted(new_preds)

    if abs(align_diff) > 10:
        keys, preds = keys[:-30], preds[:-30]

    unfinished_ids = range(len(keys))
    text_matches = []

    for n in range(max_n, 1, -1):
        ngram_preds = ngrams(preds, n)

        for n_allow_missing in range(0, max_allow_missing_when_matching + 1):
            unfinished_ids = list(set(unfinished_ids) - set(finished_ids))

            existing_ngram_preds = []

            for ngram in ngram_preds:
                for text in book_texts:
                    candidates = [
                        " ".join(text.split()[:-n_allow_missing]),
                        " ".join(text.split()[n_allow_missing:]),
                    ]
                    for tmp_text in candidates:
                        if " ".join(ngram) == tmp_text:
                            existing_ngram_preds.append(ngram)
                            break

            tmp_keys = []
            cur_ngram = []

            ngram_idx = 0
            ngram_found = False

            for id_idx in unfinished_ids:
                key, pred = keys[id_idx], preds[id_idx]

                if ngram_idx >= len(existing_ngram_preds):
                    break

                cur_ngram = existing_ngram_preds[ngram_idx]

                if pred in cur_ngram:
                    ngram_found = True

                    tmp_keys.append(key)
                    finished_ids.append(id_idx)

                    if len(tmp_keys) == len(cur_ngram):
                        if debug: print(n_allow_missing, tmp_keys, cur_ngram)

                        tmp_keys = get_continuous_audio_paths(tmp_keys, debug)
                        text_matches.append(
                            [[" ".join(cur_ngram)], tmp_keys]
                        )

                        ngram_idx += 1
                        tmp_keys = []
                        cur_ngram = []
                else:
                    if pred == cur_ngram[-1]:
                        ngram_idx += 1
                        tmp_keys = []
                        cur_ngram = []
                    else:
                        if len(tmp_keys) > 0:
                            ngram_found = False

                            tmp_keys = []
                            cur_ngram = []

    for id_idx in range(len(keys)):
        if id_idx not in finished_ids:
            key, pred = keys[id_idx], preds[id_idx]

            text_matches.append(
                [[pred], [key]]
            )

    ##############################################################
    # ngram again for just in case after adding missing words
    ##############################################################

    max_keys = [max(get_audio_ids_from_paths(item[1], as_int=True)) for item in text_matches]
    sorted_text_matches = \
        [item for _, item in sorted(zip(max_keys, text_matches))]

    preds = [item[0][0] for item in sorted_text_matches]
    keys = [item[1] for item in sorted_text_matches]

    def book_sentence_idx_search(query, book_texts):
        for idx, text in enumerate(book_texts):
            if query in text:
                return idx, text
        return False, False

    text_matches = []
    idx, book_cursor_idx = 0, 0

    if len(preds) == 0:
        return []

    while True:
        tmp_texts = book_texts[book_cursor_idx:]

        jdx = 0
        tmp_pred = preds[idx]
        idxes_to_merge = [idx]

        prev_sent_idx, prev_sent = book_sentence_idx_search(tmp_pred, tmp_texts)
        while idx + jdx + 1 < len(preds):
            jdx += 1

            tmp_pred = preds[idx + jdx]
            sent_idx, sent = book_sentence_idx_search(tmp_pred, tmp_texts)

            if not sent_idx:
                if debug: print(" [!] NOT FOUND: {}".format(tmp_pred))
                break

            if prev_sent_idx == sent_idx:
                idxes_to_merge.append(idx + jdx)
            else:
                break

        new_keys = get_continuous_audio_paths(
            sum([keys[jdx] for jdx in idxes_to_merge], []))
        text_matches.append([ [tmp_texts[prev_sent_idx]], new_keys ])

        if len(new_keys) > 1:
            book_cursor_idx += 1

        book_cursor_idx = max(book_cursor_idx, sent_idx)

        if idx == len(preds) - 1:
            break
        idx = idx + jdx

    # Counter([len(i) for i in text_matches.values()])
    return text_matches

def get_text_from_audio_batch(paths, multi_process=False):
    results = {}
    items = parallel_run(get_text_from_audio, paths,
                         desc="get_text_from_audio_batch")
    for item in items:
        results.update(item)
    return results

def get_text_from_audio(path):
    error_count = 0

    txt_path = path.replace('flac', 'txt')

    if os.path.exists(txt_path):
        with open(txt_path) as f:
            out = json.loads(open(txt_path).read())
            return out

    out = {}
    while True:
        try:
            client = speech.SpeechClient()

            with io.open(path, 'rb') as audio_file:
                content = audio_file.read()
                audio = types.RecognitionAudio(content=content)

            config = types.RecognitionConfig(
                encoding=enums.RecognitionConfig.AudioEncoding.FLAC,
                sample_rate_hertz=16000,
                language_code='ko-KR')

            response = client.recognize(config, audio)
            if len(response.results) > 0:
                alternatives = response.results[0].alternatives

                results = [alternative.transcript for alternative in alternatives]
                assert len(results) == 1, "More than 1 results: {}".format(results)

                out = { path: "" if len(results) == 0 else results[0] }
                print(results[0])
                break
            break
        except:
            error_count += 1
            print("Skip warning for {} for {} times". \
                    format(path, error_count))

            if error_count > 5:
                break
            else:
                continue

    with open(txt_path, 'w') as f:
        json.dump(out, f, indent=2, ensure_ascii=False)

    return out

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--asset-dir', type=str, default='assets')
    parser.add_argument('--data-dir', type=str, default='audio')
    parser.add_argument('--pattern', type=str, default="audio/*.flac")
    parser.add_argument('--metadata', type=str, default="metadata.json")
    config, unparsed = parser.parse_known_args()

    paths = glob(config.pattern)
    paths.sort()
    paths = paths

    book_ids = list(set([
        os.path.basename(path).split('.', 1)[0] for path in paths]))
    book_ids.sort()

    def get_finished_ids():
        finished_paths = glob(os.path.join(
            config.asset_dir, "speech-*.json"))
        finished_ids = list(set([
            os.path.basename(path).split('.', 1)[0].replace("speech-", "") for path in finished_paths]))
        finished_ids.sort()
        return finished_ids

    finished_ids = get_finished_ids()

    print("# Finished : {}/{}".format(len(finished_ids), len(book_ids)))

    book_ids_to_parse = list(set(book_ids) - set(finished_ids))
    book_ids_to_parse.sort()

    assert os.path.exists(config.asset_dir), "assert_dir not found"

    pbar = tqdm(book_ids_to_parse, "[1] google_speech",
                initial=len(finished_ids), total=len(book_ids))

    for book_id in pbar:
        current_paths = glob(config.pattern.replace("*", "{}.*".format(book_id)))
        pbar.set_description("[1] google_speech : {}".format(book_id))

        results = get_text_from_audio_batch(current_paths)

        filename = "speech-{}.json".format(book_id)
        path = os.path.join(config.asset_dir, filename)

        with open(path, "w") as f:
            json.dump(results, f, indent=2, ensure_ascii=False)

    finished_ids = get_finished_ids()

    for book_id in tqdm(finished_ids, "[2] text_match"):
        filename = "speech-{}.json".format(book_id)
        path = os.path.join(config.asset_dir, filename)
        clean_path = path.replace("speech", "clean-speech")

        if os.path.exists(clean_path):
            print(" [*] Skip {}".format(clean_path))
        else:
            results = replace_pred_with_book(path)
            with open(clean_path, "w") as f:
                json.dump(results, f, indent=2, ensure_ascii=False)

    # Dummy

    if False:
        match_paths = get_paths_by_pattern(
            config.asset_dir, 'clean-speech-*.json')

        metadata_path = os.path.join(config.data_dir, config.metadata)

        print(" [3] Merge clean-speech-*.json into {}".format(metadata_path))

        merged_data = []
        for path in match_paths:
            with open(path) as f:
                merged_data.extend(json.loads(f.read()))

        import ipdb; ipdb.set_trace()

        with open(metadata_path, 'w') as f:
            json.dump(merged_data, f, indent=2, ensure_ascii=False)

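Most of the matching above hinges on `similarity()`: it compares whitespace- and punctuation-stripped strings with `difflib.SequenceMatcher` and subtracts 1 from the score whenever the recognized text is shorter than the book text, so truncated recognitions can never out-rank a full match. A self-contained illustration of that scoring rule (punctuation stripping omitted; the example strings come from the comment in the code above):

    from difflib import SequenceMatcher

    def score(pred, book):
        a, b = "".join(pred.split()), "".join(book.split())
        ratio = SequenceMatcher(None, a, b).ratio()
        return ratio if len(a) >= len(b) else ratio - 1

    print(score("오누이는 서로 자기가 할 일을 정했다.", "오누이는 서로 자기가 할 일을 정했다."))  # 1.0
    print(score("서로 자기가 할 일을 정했다.", "오누이는 서로 자기가 할 일을 정했다."))         # negative: prediction is shorter
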
143  audio/silence.py  Normal file
@@ -0,0 +1,143 @@
import os
import re
import sys
import json
import librosa
import argparse
import numpy as np
from tqdm import tqdm
from glob import glob
from pydub import silence
from pydub import AudioSegment
from functools import partial

from hparams import hparams
from utils import parallel_run, add_postfix
from audio import load_audio, save_audio, get_duration, get_silence

def abs_mean(x):
    return abs(x).mean()

def remove_breath(audio):
    edges = librosa.effects.split(
        audio, top_db=40, frame_length=128, hop_length=32)

    for idx in range(len(edges)):
        start_idx, end_idx = edges[idx][0], edges[idx][1]
        if start_idx < len(audio):
            if abs_mean(audio[start_idx:end_idx]) < abs_mean(audio) - 0.05:
                audio[start_idx:end_idx] = 0

    return audio

def split_on_silence_with_librosa(
        audio_path, top_db=40, frame_length=1024, hop_length=256,
        skip_idx=0, out_ext="wav",
        min_segment_length=3, max_segment_length=8,
        pre_silence_length=0, post_silence_length=0):

    filename = os.path.basename(audio_path).split('.', 1)[0]
    in_ext = audio_path.rsplit(".")[1]

    audio = load_audio(audio_path)

    edges = librosa.effects.split(audio,
        top_db=top_db, frame_length=frame_length, hop_length=hop_length)

    new_audio = np.zeros_like(audio)
    for idx, (start, end) in enumerate(edges[skip_idx:]):
        new_audio[start:end] = remove_breath(audio[start:end])

    save_audio(new_audio, add_postfix(audio_path, "no_breath"))
    audio = new_audio
    edges = librosa.effects.split(audio,
        top_db=top_db, frame_length=frame_length, hop_length=hop_length)

    audio_paths = []
    for idx, (start, end) in enumerate(edges[skip_idx:]):
        segment = audio[start:end]
        duration = get_duration(segment)

        if duration <= min_segment_length or duration >= max_segment_length:
            continue

        output_path = "{}/{}.{:04d}.{}".format(
            os.path.dirname(audio_path), filename, idx, out_ext)

        padded_segment = np.concatenate([
            get_silence(pre_silence_length),
            segment,
            get_silence(post_silence_length),
        ])

        save_audio(padded_segment, output_path)
        audio_paths.append(output_path)

    return audio_paths

def read_audio(audio_path):
    return AudioSegment.from_file(audio_path)

def split_on_silence_with_pydub(
        audio_path, skip_idx=0, out_ext="wav",
        silence_thresh=-40, min_silence_len=400,
        silence_chunk_len=100, keep_silence=100):

    filename = os.path.basename(audio_path).split('.', 1)[0]
    in_ext = audio_path.rsplit(".")[1]

    audio = read_audio(audio_path)
    not_silence_ranges = silence.detect_nonsilent(
        audio, min_silence_len=silence_chunk_len,
        silence_thresh=silence_thresh)

    edges = [not_silence_ranges[0]]

    for idx in range(1, len(not_silence_ranges)-1):
        cur_start = not_silence_ranges[idx][0]
        prev_end = edges[-1][1]

        if cur_start - prev_end < min_silence_len:
            edges[-1][1] = not_silence_ranges[idx][1]
        else:
            edges.append(not_silence_ranges[idx])

    audio_paths = []
    for idx, (start_idx, end_idx) in enumerate(edges[skip_idx:]):
        start_idx = max(0, start_idx - keep_silence)
        end_idx += keep_silence

        target_audio_path = "{}/{}.{:04d}.{}".format(
            os.path.dirname(audio_path), filename, idx, out_ext)

        audio[start_idx:end_idx].export(target_audio_path, out_ext)

        audio_paths.append(target_audio_path)

    return audio_paths

def split_on_silence_batch(audio_paths, method, **kargv):
    audio_paths.sort()
    method = method.lower()

    if method == "librosa":
        fn = partial(split_on_silence_with_librosa, **kargv)
    elif method == "pydub":
        fn = partial(split_on_silence_with_pydub, **kargv)

    parallel_run(fn, audio_paths,
                 desc="Split on silence", parallel=False)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--audio_pattern', required=True)
    parser.add_argument('--out_ext', default='wav')
    parser.add_argument('--method', choices=['librosa', 'pydub'], required=True)
    config = parser.parse_args()

    audio_paths = glob(config.audio_pattern)

    split_on_silence_batch(
        audio_paths, config.method,
        out_ext=config.out_ext,
    )

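Besides the CLI entry point above, the pydub-based splitter can be called directly from Python. A minimal sketch (the input filename is illustrative):

    from audio.silence import split_on_silence_with_pydub

    # Writes <name>.0000.wav, <name>.0001.wav, ... next to the input file and returns their paths.
    segment_paths = split_on_silence_with_pydub(
        "./datasets/son/audio/example.wav",
        silence_thresh=-40,     # dBFS threshold used to detect non-silent ranges
        min_silence_len=400,    # gaps shorter than this (ms) are merged into one segment
    )
    print(len(segment_paths), "segments written")
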
0  datasets/__init__.py  Normal file

328
datasets/datafeeder.py
Normal file
328
datasets/datafeeder.py
Normal file
|
@ -0,0 +1,328 @@
|
|||
import os
|
||||
import time
|
||||
import pprint
|
||||
import random
|
||||
import threading
|
||||
import traceback
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
import tensorflow as tf
|
||||
from collections import defaultdict
|
||||
|
||||
import text
|
||||
from utils.infolog import log
|
||||
from utils import parallel_run, remove_file
|
||||
from audio import frames_to_hours
|
||||
from audio.get_duration import get_durations
|
||||
|
||||
|
||||
_pad = 0
|
||||
|
||||
def get_frame(path):
|
||||
data = np.load(path)
|
||||
n_frame = data["linear"].shape[0]
|
||||
n_token = len(data["tokens"])
|
||||
return (path, n_frame, n_token)
|
||||
|
||||
def get_path_dict(
|
||||
data_dirs, hparams, config,
|
||||
data_type, n_test=None,
|
||||
rng=np.random.RandomState(123)):
|
||||
|
||||
# Load metadata:
|
||||
path_dict = {}
|
||||
for data_dir in data_dirs:
|
||||
paths = glob("{}/*.npz".format(data_dir))
|
||||
|
||||
if data_type == 'train':
|
||||
rng.shuffle(paths)
|
||||
|
||||
if not config.skip_path_filter:
|
||||
items = parallel_run(
|
||||
get_frame, paths, desc="filter_by_min_max_frame_batch", parallel=True)
|
||||
|
||||
min_n_frame = hparams.reduction_factor * hparams.min_iters
|
||||
max_n_frame = hparams.reduction_factor * hparams.max_iters - hparams.reduction_factor
|
||||
|
||||
new_items = [(path, n) for path, n, n_tokens in items \
|
||||
if min_n_frame <= n <= max_n_frame and n_tokens >= hparams.min_tokens]
|
||||
|
||||
if any(check in data_dir for check in ["son", "yuinna"]):
|
||||
blacklists = [".0000.", ".0001.", "NB11479580.0001"]
|
||||
new_items = [item for item in new_items \
|
||||
if any(check not in item[0] for check in blacklists)]
|
||||
|
||||
new_paths = [path for path, n in new_items]
|
||||
new_n_frames = [n for path, n in new_items]
|
||||
|
||||
hours = frames_to_hours(new_n_frames)
|
||||
|
||||
log(' [{}] Loaded metadata for {} examples ({:.2f} hours)'. \
|
||||
format(data_dir, len(new_n_frames), hours))
|
||||
log(' [{}] Max length: {}'.format(data_dir, max(new_n_frames)))
|
||||
log(' [{}] Min length: {}'.format(data_dir, min(new_n_frames)))
|
||||
else:
|
||||
new_paths = paths
|
||||
|
||||
if data_type == 'train':
|
||||
new_paths = new_paths[:-n_test]
|
||||
elif data_type == 'test':
|
||||
new_paths = new_paths[-n_test:]
|
||||
else:
|
||||
raise Exception(" [!] Unkown data_type: {}".format(data_type))
|
||||
|
||||
path_dict[data_dir] = new_paths
|
||||
|
||||
return path_dict
|
||||
|
||||
class DataFeeder(threading.Thread):
|
||||
'''Feeds batches of data into a queue on a background thread.'''
|
||||
|
||||
def __init__(self, coordinator, data_dirs,
|
||||
hparams, config, batches_per_group, data_type, batch_size):
|
||||
super(DataFeeder, self).__init__()
|
||||
|
||||
self._coord = coordinator
|
||||
self._hp = hparams
|
||||
self._cleaner_names = [x.strip() for x in hparams.cleaners.split(',')]
|
||||
self._step = 0
|
||||
self._offset = defaultdict(lambda: 2)
|
||||
self._batches_per_group = batches_per_group
|
||||
|
||||
self.rng = np.random.RandomState(config.random_seed)
|
||||
self.data_type = data_type
|
||||
self.batch_size = batch_size
|
||||
|
||||
self.min_tokens = hparams.min_tokens
|
||||
self.min_n_frame = hparams.reduction_factor * hparams.min_iters
|
||||
self.max_n_frame = hparams.reduction_factor * hparams.max_iters - hparams.reduction_factor
|
||||
self.skip_path_filter = config.skip_path_filter
|
||||
|
||||
# Load metadata:
|
||||
self.path_dict = get_path_dict(
|
||||
data_dirs, self._hp, config, self.data_type,
|
||||
n_test=self.batch_size, rng=self.rng)
|
||||
|
||||
self.data_dirs = list(self.path_dict.keys())
|
||||
self.data_dir_to_id = {
|
||||
data_dir: idx for idx, data_dir in enumerate(self.data_dirs)}
|
||||
|
||||
data_weight = {
|
||||
data_dir: 1. for data_dir in self.data_dirs
|
||||
}
|
||||
|
||||
if self._hp.main_data_greedy_factor > 0 and \
|
||||
any(main_data in data_dir for data_dir in self.data_dirs \
|
||||
for main_data in self._hp.main_data):
|
||||
for main_data in self._hp.main_data:
|
||||
for data_dir in self.data_dirs:
|
||||
if main_data in data_dir:
|
||||
data_weight[data_dir] += self._hp.main_data_greedy_factor
|
||||
|
||||
weight_Z = sum(data_weight.values())
|
||||
self.data_ratio = {
|
||||
data_dir: weight / weight_Z for data_dir, weight in data_weight.items()
|
||||
}
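# data_ratio sums to 1 over all data dirs; dirs matching main_data get an extra
# greedy weight before normalization.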
|
||||
|
||||
log("="*40)
|
||||
log(pprint.pformat(self.data_ratio, indent=4))
|
||||
log("="*40)
|
||||
|
||||
#audio_paths = [path.replace("/data/", "/audio/"). \
|
||||
# replace(".npz", ".wav") for path in self.data_paths]
|
||||
#duration = get_durations(audio_paths, print_detail=False)
|
||||
|
||||
# Create placeholders for inputs and targets. Don't specify batch size because we want to
|
||||
# be able to feed different sized batches at eval time.
|
||||
|
||||
self._placeholders = [
|
||||
tf.placeholder(tf.int32, [None, None], 'inputs'),
|
||||
tf.placeholder(tf.int32, [None], 'input_lengths'),
|
||||
tf.placeholder(tf.float32, [None], 'loss_coeff'),
|
||||
tf.placeholder(tf.float32, [None, None, hparams.num_mels], 'mel_targets'),
|
||||
tf.placeholder(tf.float32, [None, None, hparams.num_freq], 'linear_targets'),
|
||||
]
|
||||
|
||||
# Create queue for buffering data:
|
||||
dtypes = [tf.int32, tf.int32, tf.float32, tf.float32, tf.float32]
|
||||
|
||||
self.is_multi_speaker = len(self.data_dirs) > 1
|
||||
|
||||
if self.is_multi_speaker:
|
||||
self._placeholders.append(
|
||||
tf.placeholder(tf.int32, [None], 'speaker_id'),
|
||||
)
|
||||
dtypes.append(tf.int32)
|
||||
|
||||
num_worker = 8 if self.data_type == 'train' else 1
|
||||
queue = tf.FIFOQueue(num_worker, dtypes, name='input_queue')
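# Queue capacity equals num_worker, so at most that many batches are buffered ahead of the consumer.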
|
||||
|
||||
self._enqueue_op = queue.enqueue(self._placeholders)
|
||||
|
||||
if self.is_multi_speaker:
|
||||
self.inputs, self.input_lengths, self.loss_coeff, \
|
||||
self.mel_targets, self.linear_targets, self.speaker_id = queue.dequeue()
|
||||
else:
|
||||
self.inputs, self.input_lengths, self.loss_coeff, \
|
||||
self.mel_targets, self.linear_targets = queue.dequeue()
|
||||
|
||||
self.inputs.set_shape(self._placeholders[0].shape)
|
||||
self.input_lengths.set_shape(self._placeholders[1].shape)
|
||||
self.loss_coeff.set_shape(self._placeholders[2].shape)
|
||||
self.mel_targets.set_shape(self._placeholders[3].shape)
|
||||
self.linear_targets.set_shape(self._placeholders[4].shape)
|
||||
|
||||
if self.is_multi_speaker:
|
||||
self.speaker_id.set_shape(self._placeholders[5].shape)
|
||||
else:
|
||||
self.speaker_id = None
|
||||
|
||||
if self.data_type == 'test':
|
||||
examples = []
|
||||
while True:
|
||||
for data_dir in self.data_dirs:
|
||||
examples.append(self._get_next_example(data_dir))
|
||||
#print(data_dir, text.sequence_to_text(examples[-1][0], False, True))
|
||||
if len(examples) >= self.batch_size:
|
||||
break
|
||||
if len(examples) >= self.batch_size:
|
||||
break
|
||||
self.static_batches = [examples for _ in range(self._batches_per_group)]
|
||||
|
||||
else:
|
||||
self.static_batches = None
|
||||
|
||||
def start_in_session(self, session, start_step):
|
||||
self._step = start_step
|
||||
self._session = session
|
||||
self.start()
|
||||
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
while not self._coord.should_stop():
|
||||
self._enqueue_next_group()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
self._coord.request_stop(e)
|
||||
|
||||
|
||||
def _enqueue_next_group(self):
|
||||
start = time.time()
|
||||
|
||||
# Read a group of examples:
|
||||
n = self.batch_size
|
||||
r = self._hp.reduction_factor
|
||||
|
||||
if self.static_batches is not None:
|
||||
batches = self.static_batches
|
||||
else:
|
||||
examples = []
|
||||
for data_dir in self.data_dirs:
|
||||
if self._hp.initial_data_greedy:
|
||||
if self._step < self._hp.initial_phase_step and \
|
||||
any("krbook" in data_dir for data_dir in self.data_dirs):
|
||||
data_dir = [data_dir for data_dir in self.data_dirs if "krbook" in data_dir][0]
|
||||
|
||||
if self._step < self._hp.initial_phase_step:
|
||||
example = [self._get_next_example(data_dir) \
|
||||
for _ in range(int(n * self._batches_per_group // len(self.data_dirs)))]
|
||||
else:
|
||||
example = [self._get_next_example(data_dir) \
|
||||
for _ in range(int(n * self._batches_per_group * self.data_ratio[data_dir]))]
|
||||
examples.extend(example)
|
||||
examples.sort(key=lambda x: x[-1])
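# Sort by target length (the last tuple element) so each batch groups similarly long
# utterances and wastes less padding.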
|
||||
|
||||
batches = [examples[i:i+n] for i in range(0, len(examples), n)]
|
||||
self.rng.shuffle(batches)
|
||||
|
||||
log('Generated %d batches of size %d in %.03f sec' % (len(batches), n, time.time() - start))
|
||||
for batch in batches:
|
||||
feed_dict = dict(zip(self._placeholders, _prepare_batch(batch, r, self.rng, self.data_type)))
|
||||
self._session.run(self._enqueue_op, feed_dict=feed_dict)
|
||||
self._step += 1
|
||||
|
||||
|
||||
def _get_next_example(self, data_dir):
|
||||
'''Loads a single example (input, mel_target, linear_target, cost) from disk'''
|
||||
data_paths = self.path_dict[data_dir]
|
||||
|
||||
while True:
|
||||
if self._offset[data_dir] >= len(data_paths):
|
||||
self._offset[data_dir] = 0
|
||||
|
||||
if self.data_type == 'train':
|
||||
self.rng.shuffle(data_paths)
|
||||
|
||||
data_path = data_paths[self._offset[data_dir]]
|
||||
self._offset[data_dir] += 1
|
||||
|
||||
try:
|
||||
if os.path.exists(data_path):
|
||||
data = np.load(data_path)
|
||||
else:
|
||||
continue
|
||||
except:
|
||||
remove_file(data_path)
|
||||
continue
|
||||
|
||||
if not self.skip_path_filter:
|
||||
break
|
||||
|
||||
if self.min_n_frame <= data["linear"].shape[0] <= self.max_n_frame and \
|
||||
len(data["tokens"]) > self.min_tokens:
|
||||
break
|
||||
|
||||
input_data = data['tokens']
|
||||
mel_target = data['mel']
|
||||
|
||||
if 'loss_coeff' in data:
|
||||
loss_coeff = data['loss_coeff']
|
||||
else:
|
||||
loss_coeff = 1
|
||||
linear_target = data['linear']
|
||||
|
||||
return (input_data, loss_coeff, mel_target, linear_target,
|
||||
self.data_dir_to_id[data_dir], len(linear_target))
|
||||
|
||||
|
||||
def _prepare_batch(batch, reduction_factor, rng, data_type=None):
|
||||
if data_type == 'train':
|
||||
rng.shuffle(batch)
|
||||
|
||||
inputs = _prepare_inputs([x[0] for x in batch])
|
||||
input_lengths = np.asarray([len(x[0]) for x in batch], dtype=np.int32)
|
||||
loss_coeff = np.asarray([x[1] for x in batch], dtype=np.float32)
|
||||
|
||||
mel_targets = _prepare_targets([x[2] for x in batch], reduction_factor)
|
||||
linear_targets = _prepare_targets([x[3] for x in batch], reduction_factor)
|
||||
|
||||
if len(batch[0]) == 6:
|
||||
speaker_id = np.asarray([x[4] for x in batch], dtype=np.int32)
|
||||
return (inputs, input_lengths, loss_coeff,
|
||||
mel_targets, linear_targets, speaker_id)
|
||||
else:
|
||||
return (inputs, input_lengths, loss_coeff, mel_targets, linear_targets)
|
||||
|
||||
|
||||
def _prepare_inputs(inputs):
|
||||
max_len = max((len(x) for x in inputs))
|
||||
return np.stack([_pad_input(x, max_len) for x in inputs])
|
||||
|
||||
|
||||
def _prepare_targets(targets, alignment):
|
||||
max_len = max((len(t) for t in targets)) + 1
|
||||
return np.stack([_pad_target(t, _round_up(max_len, alignment)) for t in targets])
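# Targets are padded up to a multiple of the reduction factor so the decoder emits whole groups of r frames per step.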
|
||||
|
||||
|
||||
def _pad_input(x, length):
|
||||
return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
|
||||
|
||||
|
||||
def _pad_target(t, length):
|
||||
return np.pad(t, [(0, length - t.shape[0]), (0,0)], mode='constant', constant_values=_pad)
|
||||
|
||||
|
||||
def _round_up(x, multiple):
|
||||
remainder = x % multiple
|
||||
return x if remainder == 0 else x + multiple - remainder
|
191
datasets/generate_data.py
Normal file
@@ -0,0 +1,191 @@
# Code based on https://github.com/keithito/tacotron/blob/master/datasets/ljspeech.py
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from glob import glob
|
||||
from functools import partial
|
||||
|
||||
from collections import Counter, defaultdict
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
import matplotlib
|
||||
matplotlib.use('agg')
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from hparams import hparams
|
||||
from text import text_to_sequence
|
||||
from utils import makedirs, remove_file, warning
|
||||
from audio import load_audio, spectrogram, melspectrogram, frames_to_hours
|
||||
|
||||
def one(x=None):
|
||||
return 1
|
||||
|
||||
def build_from_path(config):
|
||||
warning("Sampling rate: {}".format(hparams.sample_rate))
|
||||
|
||||
executor = ProcessPoolExecutor(max_workers=config.num_workers)
|
||||
futures = []
|
||||
index = 1
|
||||
|
||||
base_dir = os.path.dirname(config.metadata_path)
|
||||
data_dir = os.path.join(base_dir, config.data_dirname)
|
||||
makedirs(data_dir)
|
||||
|
||||
loss_coeff = defaultdict(one)
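# Every utterance starts with loss weight 1; transcripts taken from speech recognition get a reduced weight below.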
|
||||
if config.metadata_path.endswith("json"):
|
||||
with open(config.metadata_path) as f:
|
||||
content = f.read()
|
||||
info = json.loads(content)
|
||||
elif config.metadata_path.endswith("csv"):
|
||||
with open(config.metadata_path) as f:
|
||||
info = {}
|
||||
for line in f:
|
||||
path, text = line.strip().split('|')
|
||||
info[path] = text
|
||||
else:
|
||||
raise Exception(" [!] Unkown metadata format: {}".format(config.metadata_path))
|
||||
|
||||
new_info = {}
|
||||
for path in info.keys():
|
||||
if not os.path.exists(path):
|
||||
new_path = os.path.join(base_dir, path)
|
||||
if not os.path.exists(new_path):
|
||||
print(" [!] Audio not found: {}".format([path, new_path]))
|
||||
continue
|
||||
else:
|
||||
new_path = path
|
||||
|
||||
new_info[new_path] = info[path]
|
||||
|
||||
info = new_info
|
||||
|
||||
for path in info.keys():
|
||||
if type(info[path]) == list:
|
||||
if (hparams.ignore_recognition_level == 1 and len(info[path]) == 1) or \
|
||||
hparams.ignore_recognition_level == 2:
|
||||
loss_coeff[path] = hparams.recognition_loss_coeff
|
||||
|
||||
info[path] = info[path][0]
|
||||
|
||||
ignore_description = {
|
||||
0: "use all",
|
||||
1: "ignore only unmatched_alignment",
|
||||
2: "fully ignore recognitio",
|
||||
}
|
||||
|
||||
print(" [!] Skip recognition level: {} ({})". \
|
||||
format(hparams.ignore_recognition_level,
|
||||
ignore_description[hparams.ignore_recognition_level]))
|
||||
|
||||
for audio_path, text in info.items():
|
||||
if hparams.ignore_recognition_level > 0 and loss_coeff[audio_path] != 1:
|
||||
continue
|
||||
|
||||
if base_dir not in audio_path:
|
||||
audio_path = os.path.join(base_dir, audio_path)
|
||||
|
||||
try:
|
||||
tokens = text_to_sequence(text)
|
||||
except:
|
||||
continue
|
||||
|
||||
fn = partial(
|
||||
_process_utterance,
|
||||
audio_path, data_dir, tokens, loss_coeff[audio_path])
|
||||
futures.append(executor.submit(fn))
|
||||
|
||||
n_frames = [future.result() for future in tqdm(futures)]
|
||||
n_frames = [n_frame for n_frame in n_frames if n_frame is not None]
|
||||
|
||||
hours = frames_to_hours(n_frames)
|
||||
|
||||
print(' [*] Loaded metadata for {} examples ({:.2f} hours)'.format(len(n_frames), hours))
|
||||
print(' [*] Max length: {}'.format(max(n_frames)))
|
||||
print(' [*] Min length: {}'.format(min(n_frames)))
|
||||
|
||||
plot_n_frames(n_frames, os.path.join(
|
||||
base_dir, "n_frames_before_filter.png"))
|
||||
|
||||
min_n_frame = hparams.reduction_factor * hparams.min_iters
|
||||
max_n_frame = hparams.reduction_factor * hparams.max_iters - hparams.reduction_factor
|
||||
|
||||
n_frames = [n for n in n_frames if min_n_frame <= n <= max_n_frame]
|
||||
hours = frames_to_hours(n_frames)
|
||||
|
||||
print(' [*] After filtered: {} examples ({:.2f} hours)'.format(len(n_frames), hours))
|
||||
print(' [*] Max length: {}'.format(max(n_frames)))
|
||||
print(' [*] Min length: {}'.format(min(n_frames)))
|
||||
|
||||
plot_n_frames(n_frames, os.path.join(
|
||||
base_dir, "n_frames_after_filter.png"))
|
||||
|
||||
def plot_n_frames(n_frames, path):
|
||||
labels, values = list(zip(*Counter(n_frames).most_common()))
|
||||
|
||||
values = [v for _, v in sorted(zip(labels, values))]
|
||||
labels = sorted(labels)
|
||||
|
||||
indexes = np.arange(len(labels))
|
||||
width = 1
|
||||
|
||||
fig, ax = plt.subplots(figsize=(len(labels) / 2, 5))
|
||||
|
||||
plt.bar(indexes, values, width)
|
||||
plt.xticks(indexes + width * 0.5, labels)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(path)
|
||||
|
||||
|
||||
def _process_utterance(audio_path, data_dir, tokens, loss_coeff):
|
||||
audio_name = os.path.basename(audio_path)
|
||||
|
||||
filename = audio_name.rsplit('.', 1)[0] + ".npz"
|
||||
numpy_path = os.path.join(data_dir, filename)
|
||||
|
||||
if not os.path.exists(numpy_path):
|
||||
wav = load_audio(audio_path)
|
||||
|
||||
linear_spectrogram = spectrogram(wav).astype(np.float32)
|
||||
mel_spectrogram = melspectrogram(wav).astype(np.float32)
|
||||
|
||||
data = {
|
||||
"linear": linear_spectrogram.T,
|
||||
"mel": mel_spectrogram.T,
|
||||
"tokens": tokens,
|
||||
"loss_coeff": loss_coeff,
|
||||
}
|
||||
|
||||
n_frame = linear_spectrogram.shape[1]
|
||||
|
||||
if hparams.skip_inadequate:
|
||||
min_n_frame = hparams.reduction_factor * hparams.min_iters
|
||||
max_n_frame = hparams.reduction_factor * hparams.max_iters - hparams.reduction_factor
|
||||
|
||||
if not (min_n_frame <= n_frame <= max_n_frame and len(tokens) >= hparams.min_tokens):
|
||||
return None
|
||||
|
||||
np.savez(numpy_path, **data, allow_pickle=False)
|
||||
else:
|
||||
try:
|
||||
data = np.load(numpy_path)
|
||||
n_frame = data["linear"].shape[0]
|
||||
except:
|
||||
remove_file(numpy_path)
|
||||
return _process_utterance(audio_path, data_dir, tokens, loss_coeff)
|
||||
|
||||
return n_frame
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='spectrogram')
|
||||
|
||||
parser.add_argument('metadata_path', type=str)
|
||||
parser.add_argument('--data_dirname', type=str, default="data")
|
||||
parser.add_argument('--num_workers', type=int, default=None)
|
||||
|
||||
config = parser.parse_args()
|
||||
build_from_path(config)
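# Example invocation (assumed; any metadata .json/.csv mapping audio paths to text works):
#   python -m datasets.generate_data datasets/park/recognition.json --num_workers 8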
|
59
datasets/moon/download.py
Normal file
@@ -0,0 +1,59 @@
import os
|
||||
import youtube_dl
|
||||
from pydub import AudioSegment
|
||||
|
||||
from utils import makedirs, remove_file
|
||||
|
||||
|
||||
base_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
def get_mili_sec(text):
|
||||
minute, second = text.strip().split(':')
|
||||
return (int(minute) * 60 + int(second)) * 1000
|
||||
|
||||
class Data(object):
|
||||
def __init__(
|
||||
self, text_path, video_url, title, start_time, end_time):
|
||||
self.text_path = text_path
|
||||
self.video_url = video_url
|
||||
self.title = title
|
||||
self.start = get_mili_sec(start_time)
|
||||
self.end = get_mili_sec(end_time)
|
||||
|
||||
def read_csv(path):
|
||||
with open(path) as f:
|
||||
data = []
|
||||
for line in f:
|
||||
text_path, video_url, title, start_time, end_time = line.split('|')
|
||||
data.append(Data(text_path, video_url, title, start_time, end_time))
|
||||
return data
|
||||
|
||||
def download_audio_with_urls(data, out_ext="wav"):
|
||||
for d in data:
|
||||
original_path = os.path.join(base_dir, 'audio',
|
||||
os.path.basename(d.text_path)).replace('.txt', '.original.mp3')
|
||||
out_path = os.path.join(base_dir, 'audio',
|
||||
os.path.basename(d.text_path)).replace('.txt', '.wav')
|
||||
|
||||
options = {
|
||||
'format': 'bestaudio/best',
|
||||
'outtmpl': original_path,
|
||||
'postprocessors': [{
|
||||
'key': 'FFmpegExtractAudio',
|
||||
'preferredcodec': 'mp3',
|
||||
'preferredquality': '320',
|
||||
}],
|
||||
}
|
||||
with youtube_dl.YoutubeDL(options) as ydl:
|
||||
ydl.download([d.video_url])
|
||||
|
||||
audio = AudioSegment.from_file(original_path)
|
||||
audio[d.start:d.end].export(out_path, out_ext)
|
||||
|
||||
remove_file(original_path)
|
||||
|
||||
if __name__ == '__main__':
|
||||
makedirs(os.path.join(base_dir, "audio"))
|
||||
|
||||
data = read_csv(os.path.join(base_dir, "metadata.csv"))
|
||||
download_audio_with_urls(data)
|
12
datasets/moon/metadata.csv
Normal file
@@ -0,0 +1,12 @@
assets/001.txt|https://www.youtube.com/watch?v=_YWqWHe8LwE|국회 시정연설|0:56|30:05
|
||||
assets/002.txt|https://www.youtube.com/watch?v=p0iokDQy1sQ|유엔총회 기조연설|0:00|21:55
|
||||
assets/003.txt|https://www.youtube.com/watch?v=eU4xI0OR9yQ|베를린 한반도 평화구상 연설|0:00|25:06
|
||||
assets/004.txt|https://www.youtube.com/watch?v=PQXSzswJDyU|동방경제포럼 기조연설|0:00|17:58
|
||||
assets/005.txt|https://www.youtube.com/watch?v=dOYaWLddRbU|취임사1|0:01|0:37
|
||||
assets/006.txt|https://www.youtube.com/watch?v=dOYaWLddRbU|취임사2|1:09|12:45
|
||||
assets/007.txt|https://www.youtube.com/watch?v=05yqIiwpqGw|6·15 남북공동선언 17주년 기념식 축사|0:05|12:14
|
||||
assets/008.txt|https://www.youtube.com/watch?v=etwb4AR5hg4|현충일 추념사|0:00|12:05
|
||||
assets/009.txt|https://www.youtube.com/watch?v=TGZeC52r8WM|바다의 날 기념사|0:00|12:20
|
||||
assets/010.txt|https://www.youtube.com/watch?v=T2ANoBtp1p8|제72주년 광복절 경축사|0:13|29:26
|
||||
assets/011.txt|https://www.youtube.com/watch?v=HRCTTRWAbNA|남북정상회담 17주년 기념식 축사|0:07|12:03
|
||||
assets/012.txt|https://www.youtube.com/watch?v=Md5219iWdbs|2차 AIIB 연차총회 개회식 축사|0:07|15:55
|
|
4879
datasets/moon/recognition.json
Normal file
File diff suppressed because it is too large
59
datasets/park/download.py
Normal file
@@ -0,0 +1,59 @@
import os
|
||||
import youtube_dl
|
||||
from pydub import AudioSegment
|
||||
|
||||
from utils import makedirs, remove_file
|
||||
|
||||
|
||||
base_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
def get_mili_sec(text):
|
||||
minute, second = text.strip().split(':')
|
||||
return (int(minute) * 60 + int(second)) * 1000
|
||||
|
||||
class Data(object):
|
||||
def __init__(
|
||||
self, text_path, video_url, title, start_time, end_time):
|
||||
self.text_path = text_path
|
||||
self.video_url = video_url
|
||||
self.title = title
|
||||
self.start = get_mili_sec(start_time)
|
||||
self.end = get_mili_sec(end_time)
|
||||
|
||||
def read_csv(path):
|
||||
with open(path) as f:
|
||||
data = []
|
||||
for line in f:
|
||||
text_path, video_url, title, start_time, end_time = line.split('|')
|
||||
data.append(Data(text_path, video_url, title, start_time, end_time))
|
||||
return data
|
||||
|
||||
def download_audio_with_urls(data, out_ext="wav"):
|
||||
for d in data:
|
||||
original_path = os.path.join(base_dir, 'audio',
|
||||
os.path.basename(d.text_path)).replace('.txt', '.original.mp3')
|
||||
out_path = os.path.join(base_dir, 'audio',
|
||||
os.path.basename(d.text_path)).replace('.txt', '.wav')
|
||||
|
||||
options = {
|
||||
'format': 'bestaudio/best',
|
||||
'outtmpl': original_path,
|
||||
'postprocessors': [{
|
||||
'key': 'FFmpegExtractAudio',
|
||||
'preferredcodec': 'mp3',
|
||||
'preferredquality': '320',
|
||||
}],
|
||||
}
|
||||
with youtube_dl.YoutubeDL(options) as ydl:
|
||||
ydl.download([d.video_url])
|
||||
|
||||
audio = AudioSegment.from_file(original_path)
|
||||
audio[d.start:d.end].export(out_path, out_ext)
|
||||
|
||||
remove_file(original_path)
|
||||
|
||||
if __name__ == '__main__':
|
||||
makedirs(os.path.join(base_dir, "audio"))
|
||||
|
||||
data = read_csv(os.path.join(base_dir, "metadata.csv"))
|
||||
download_audio_with_urls(data)
|
21
datasets/park/metadata.csv
Normal file
@@ -0,0 +1,21 @@
assets/001.txt|https://www.youtube.com/watch?v=jn_Re6tW5Uo|개성공단 국회 연설|0:04|26:00
|
||||
assets/002.txt|https://www.youtube.com/watch?v=56WKAcps8uM|2015 대국민 담화|0:05|24:13
|
||||
assets/003.txt|https://www.youtube.com/watch?v=_Fym6railzc|2016 대국민 담화|0:11|30:47
|
||||
assets/004.txt|https://www.youtube.com/watch?v=vBYXDJkW5eY|제97주년 3ㆍ1절 기념축사|0:01|18:36
|
||||
assets/005.txt|https://www.youtube.com/watch?v=__37IbJeb4I|건군 68주년 국군의 날 기념식 축사|0:00|16:36
|
||||
assets/006.txt|https://www.youtube.com/watch?v=A_Fyx2wZB30|최순실 사건 대국민 담화 발표|0:02|8:50
|
||||
assets/007.txt|https://www.youtube.com/watch?v=8eKgE5sRsko|2016 현충일 추념사|0:00|8:28
|
||||
assets/008.txt|https://www.youtube.com/watch?v=xbrMCJn4OfQ|2014 현충일 추념사|0:00|6:45
|
||||
assets/009.txt|https://www.youtube.com/watch?v=ONBO3A6YGw8|제70차 유엔총회 기조연설|0:21|23:03
|
||||
assets/010.txt|https://www.youtube.com/watch?v=rl1lTwD5-CU|2014 신년 기자회견|0:05|99:00
|
||||
assets/011.txt|https://www.youtube.com/watch?v=iI-K6B3u-a8|2016 서해 수호의 날 기념사|0:09|8:13
|
||||
assets/012.txt|https://www.youtube.com/watch?v=SuOJEZMPGqE|연설문 사전 유출 대국민사과|0:26|1:48
|
||||
assets/013.txt|https://www.youtube.com/watch?v=BVQMycTnmAU|2017 예산안 설명 국회 시정연설|0:48|36:43
|
||||
assets/014.txt|https://www.youtube.com/watch?v=-buLcCLNeTA|2016 20대 국회 개원 연설|0:00|27:32
|
||||
assets/015.txt|https://www.youtube.com/watch?v=5G4o-v8QfFw|2014 독일 드레스덴 연설|0:00|22:29
|
||||
assets/016.txt|https://www.youtube.com/watch?v=qczKAq9gA-k|70주년 광복절 경축사|0:09|25:33
|
||||
assets/017.txt|https://www.youtube.com/watch?v=T_29pBDIfDQ|71주년 광복절 경축사|0:06|26:27
|
||||
assets/018.txt|https://www.youtube.com/watch?v=P9Rf1ERW7pE|아프리카연합(AU) 특별연설|0:07|20:04
|
||||
assets/019.txt|https://www.youtube.com/watch?v=P7K9oVBdqe0|2014 예산안 시정연설|0:01|35:14
|
||||
assets/020.txt|https://www.youtube.com/watch?v=Enuo-yOjT9M|2013 예산안 시정연설|0:00|29:15
|
||||
assets/021.txt|https://www.youtube.com/watch?v=GYHtSjMi3DU|69주년 광복절 경축사|0:00|24:19
|
|
0
datasets/park/process.py
Normal file
8776
datasets/park/recognition.json
Normal file
File diff suppressed because it is too large
150
datasets/son/download.py
Normal file
@@ -0,0 +1,150 @@
import re
|
||||
import os
|
||||
import sys
|
||||
import m3u8
|
||||
import json
|
||||
import requests
|
||||
import subprocess
|
||||
from functools import partial
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from utils import get_encoder_name, parallel_run, makedirs
|
||||
|
||||
API_URL = 'http://api.jtbc.joins.com/ad/pre/NV10173083'
|
||||
BASE_URL = 'http://nsvc.jtbc.joins.com/API/News/Newapp/Default.aspx'
|
||||
|
||||
def soupify(text):
|
||||
return BeautifulSoup(text, "html.parser")
|
||||
|
||||
def get_news_ids(page_id):
|
||||
params = {
|
||||
'NJC': 'NJC300',
|
||||
'CAID': 'NC10011174',
|
||||
'PGI': page_id,
|
||||
}
|
||||
|
||||
response = requests.request(
|
||||
method='GET', url=BASE_URL, params=params,
|
||||
)
|
||||
soup = soupify(response.text)
|
||||
|
||||
return [item.text for item in soup.find_all('news_id')]
|
||||
|
||||
def download_news_video_and_content(
|
||||
news_id, base_dir, chunk_size=32*1024,
|
||||
video_dir="video", asset_dir="assets", audio_dir="audio"):
|
||||
|
||||
video_dir = os.path.join(base_dir, video_dir)
|
||||
asset_dir = os.path.join(base_dir, asset_dir)
|
||||
audio_dir = os.path.join(base_dir, audio_dir)
|
||||
|
||||
makedirs(video_dir)
|
||||
makedirs(asset_dir)
|
||||
makedirs(audio_dir)
|
||||
|
||||
text_path = os.path.join(asset_dir, "{}.txt".format(news_id))
|
||||
original_text_path = os.path.join(asset_dir, "original-{}.txt".format(news_id))
|
||||
|
||||
video_path = os.path.join(video_dir, "{}.ts".format(news_id))
|
||||
audio_path = os.path.join(audio_dir, "{}.wav".format(news_id))
|
||||
|
||||
params = {
|
||||
'NJC': 'NJC400',
|
||||
'NID': news_id, # NB11515152
|
||||
'CD': 'A0100',
|
||||
}
|
||||
|
||||
response = requests.request(
|
||||
method='GET', url=BASE_URL, params=params,
|
||||
)
|
||||
soup = soupify(response.text)
|
||||
|
||||
article_contents = soup.find_all('article_contents')
|
||||
|
||||
assert len(article_contents) == 1, \
|
||||
"# of <article_contents> of {} should be 1: {}".format(news_id, response.text)
|
||||
|
||||
text = soupify(article_contents[0].text).get_text() # remove <div>
|
||||
|
||||
with open(original_text_path, "w") as f:
|
||||
f.write(text)
|
||||
|
||||
with open(text_path, "w") as f:
|
||||
from nltk import sent_tokenize
|
||||
|
||||
text = re.sub(r'\[.{0,80} :\s.+]', '', text) # remove quote
|
||||
text = re.sub(r'☞.+http.+\)', '', text) # remove quote
|
||||
text = re.sub(r'\(https?:\/\/.*[\r\n]*\)', '', text) # remove url
|
||||
|
||||
sentences = sent_tokenize(text)
|
||||
sentences = [sent for sentence in sentences for sent in sentence.split('\n') if sent]
|
||||
|
||||
new_texts = []
|
||||
for sent in sentences:
|
||||
sent = sent.strip()
|
||||
sent = re.sub(r'\([^)]*\)', '', sent)
|
||||
#sent = re.sub(r'\<.{0,80}\>', '', sent)
|
||||
sent = sent.replace('…', '.')
|
||||
new_texts.append(sent)
|
||||
|
||||
f.write("\n".join([sent for sent in new_texts if sent]))
|
||||
|
||||
vod_paths = soup.find_all('vod_path')
|
||||
|
||||
assert len(vod_paths) == 1, \
|
||||
"# of <vod_path> of {} should be 1: {}".format(news_id, response.text)
|
||||
|
||||
if not os.path.exists(video_path):
|
||||
redirect_url = soup.find_all('vod_path')[0].text
|
||||
|
||||
list_url = m3u8.load(redirect_url).playlists[0].absolute_uri
|
||||
video_urls = [segment.absolute_uri for segment in m3u8.load(list_url).segments]
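# The vod_path points at a master m3u8 playlist; pick its first variant and download every TS segment in order.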
|
||||
|
||||
with open(video_path, "wb") as f:
|
||||
for url in video_urls:
|
||||
response = requests.get(url, stream=True)
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
|
||||
for chunk in response.iter_content(chunk_size):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
f.write(chunk)
|
||||
|
||||
if not os.path.exists(audio_path):
|
||||
encoder = get_encoder_name()
|
||||
command = "{} -y -loglevel panic -i {} -ab 160k -ac 2 -ar 44100 -vn {}".\
|
||||
format(encoder, video_path, audio_path)
|
||||
subprocess.call(command, shell=True)
|
||||
|
||||
return True
|
||||
|
||||
if __name__ == '__main__':
|
||||
news_ids = []
|
||||
page_idx = 1
|
||||
|
||||
base_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
news_id_path = os.path.join(base_dir, "news_ids.json")
|
||||
|
||||
if not os.path.exists(news_id_path):
|
||||
while True:
|
||||
tmp_ids = get_news_ids(page_idx)
|
||||
if len(tmp_ids) == 0:
|
||||
break
|
||||
|
||||
news_ids.extend(tmp_ids)
|
||||
print(" [*] Download page {}: {}/{}".format(page_idx, len(tmp_ids), len(news_ids)))
|
||||
|
||||
page_idx += 1
|
||||
|
||||
with open(news_id_path, "w") as f:
|
||||
json.dump(news_ids, f, indent=2, ensure_ascii=False)
|
||||
else:
|
||||
with open(news_id_path) as f:
|
||||
news_ids = json.loads(f.read())
|
||||
|
||||
exceptions = ["NB10830162"]
|
||||
news_ids = list(set(news_ids) - set(exceptions))
|
||||
|
||||
fn = partial(download_news_video_and_content, base_dir=base_dir)
|
||||
|
||||
results = parallel_run(
|
||||
fn, news_ids, desc="Download news video+text", parallel=True)
|
59
datasets/yuinna/download.py
Normal file
@@ -0,0 +1,59 @@
import os
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from functools import partial
|
||||
|
||||
from utils import download_with_url, makedirs, parallel_run
|
||||
|
||||
base_path = os.path.dirname(os.path.realpath(__file__))
|
||||
RSS_URL = "http://enabler.kbs.co.kr/api/podcast_channel/feed.xml?channel_id=R2010-0440"
|
||||
|
||||
def itunes_download(item):
|
||||
audio_dir = os.path.join(base_path, "audio")
|
||||
|
||||
date, url = item
|
||||
path = os.path.join(audio_dir, "{}.mp4".format(date))
|
||||
|
||||
if not os.path.exists(path):
|
||||
download_with_url(url, path)
|
||||
|
||||
def download_all(config):
|
||||
audio_dir = os.path.join(base_path, "audio")
|
||||
makedirs(audio_dir)
|
||||
|
||||
soup = BeautifulSoup(requests.get(RSS_URL).text, "html5lib")
|
||||
|
||||
items = [item for item in soup.find_all('item')]
|
||||
|
||||
titles = [item.find('title').text[9:-3] for item in items]
|
||||
guids = [item.find('guid').text for item in items]
|
||||
|
||||
accept_list = ['친절한 인나씨', '반납예정일', '귀욤열매 드세요']
|
||||
|
||||
new_guids = [guid for title, guid in zip(titles, guids) \
|
||||
if any(accept in title for accept in accept_list) and '-' not in title]
|
||||
new_titles = [title for title, _ in zip(titles, guids) \
|
||||
if any(accept in title for accept in accept_list) and '-' not in title]
|
||||
|
||||
for idx, title in enumerate(new_titles):
|
||||
print(" [{:3d}] {}, {}".format(idx + 1, title,
|
||||
os.path.basename(new_guids[idx]).split('_')[2]))
|
||||
if idx == config.max_num: print("="*30)
|
||||
|
||||
urls = {
|
||||
os.path.basename(guid).split('_')[2]: guid \
|
||||
for guid in new_guids[:config.max_num]
|
||||
}
|
||||
|
||||
parallel_run(itunes_download, urls.items(),
|
||||
desc=" [*] Itunes download", parallel=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--max_num', default=100, type=int)
|
||||
config, unparsed = parser.parse_known_args()
|
||||
|
||||
download_all(config)
|
29165
datasets/yuinna/recognition.json
Normal file
File diff suppressed because it is too large
122
download.py
Normal file
@@ -0,0 +1,122 @@
# Code based on https://github.com/carpedm20/DCGAN-tensorflow/blob/master/download.py
|
||||
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import sys
|
||||
import gzip
|
||||
import json
|
||||
import tarfile
|
||||
import zipfile
|
||||
import argparse
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
from six.moves import urllib
|
||||
|
||||
from utils import query_yes_no
|
||||
|
||||
parser = argparse.ArgumentParser(description='Download model checkpoints.')
|
||||
parser.add_argument('checkpoints', metavar='N', type=str, nargs='+', choices=['son', 'park'],
|
||||
help='name of checkpoints to download [son, park]')
|
||||
|
||||
def download(url, dirpath):
|
||||
filename = url.split('/')[-1]
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
u = urllib.request.urlopen(url)
|
||||
f = open(filepath, 'wb')
|
||||
filesize = int(u.headers["Content-Length"])
|
||||
print("Downloading: %s Bytes: %s" % (filename, filesize))
|
||||
|
||||
downloaded = 0
|
||||
block_sz = 8192
|
||||
status_width = 70
|
||||
while True:
|
||||
buf = u.read(block_sz)
|
||||
if not buf:
|
||||
print('')
|
||||
break
|
||||
else:
|
||||
print('', end='\r')
|
||||
downloaded += len(buf)
|
||||
f.write(buf)
|
||||
status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") %
|
||||
('=' * int(float(downloaded) / filesize * status_width) + '>', downloaded * 100. / filesize))
|
||||
print(status, end='')
|
||||
sys.stdout.flush()
|
||||
f.close()
|
||||
return filepath
|
||||
|
||||
def download_file_from_google_drive(id, destination):
|
||||
URL = "https://docs.google.com/uc?export=download"
|
||||
session = requests.Session()
|
||||
|
||||
response = session.get(URL, params={ 'id': id }, stream=True)
|
||||
token = get_confirm_token(response)
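# Large files trigger Google Drive's "can't scan for viruses" page; the confirm token lets us retry and get the real download.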
|
||||
|
||||
if token:
|
||||
params = { 'id' : id, 'confirm' : token }
|
||||
response = session.get(URL, params=params, stream=True)
|
||||
|
||||
save_response_content(response, destination)
|
||||
|
||||
def get_confirm_token(response):
|
||||
for key, value in response.cookies.items():
|
||||
if key.startswith('download_warning'):
|
||||
return value
|
||||
return None
|
||||
|
||||
def save_response_content(response, destination, chunk_size=32*1024):
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
with open(destination, "wb") as f:
|
||||
for chunk in tqdm(response.iter_content(chunk_size), total=total_size,
|
||||
unit='B', unit_scale=True, desc=destination):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
f.write(chunk)
|
||||
|
||||
def unzip(filepath):
|
||||
print("Extracting: " + filepath)
|
||||
dirpath = os.path.dirname(filepath)
|
||||
with zipfile.ZipFile(filepath) as zf:
|
||||
zf.extractall(dirpath)
|
||||
os.remove(filepath)
|
||||
|
||||
def download_checkpoint(checkpoint):
|
||||
if checkpoint == "son":
|
||||
save_path, drive_id = "son-20171015.tar.gz", "0B_7wC-DuR6ORcmpaY1A5V1AzZUU"
|
||||
elif checkpoint == "park":
|
||||
save_path, drive_id = "park-20171015.tar.gz", "0B_7wC-DuR6ORYjhlekl5bVlkQ2c"
|
||||
else:
|
||||
raise Exception(" [!] Unknown checkpoint: {}".format(checkpoint))
|
||||
|
||||
if os.path.exists(save_path):
|
||||
print('[*] {} already exists'.format(save_path))
|
||||
else:
|
||||
download_file_from_google_drive(drive_id, save_path)
|
||||
|
||||
if save_path.endswith(".zip"):
|
||||
zip_dir = ''
|
||||
with zipfile.ZipFile(save_path) as zf:
|
||||
zip_dir = zf.namelist()[0]
|
||||
zf.extractall(os.path.dirname(save_path) or '.')
|
||||
os.remove(save_path)
|
||||
os.rename(os.path.join(os.path.dirname(save_path) or '.', zip_dir), checkpoint)
|
||||
elif save_path.endswith("tar.gz"):
|
||||
tar = tarfile.open(save_path, "r:gz")
|
||||
tar.extractall()
|
||||
tar.close()
|
||||
elif save_path.endswith("tar"):
|
||||
tar = tarfile.open(save_path, "r:")
|
||||
tar.extractall()
|
||||
tar.close()
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parser.parse_args()
|
||||
|
||||
print(" [!] The pre-trained models are being made available for research purpose only")
|
||||
print(" [!] 학습된 모델을 연구 이외의 목적으로 사용하는 것을 금지합니다.")
|
||||
print()
|
||||
|
||||
if query_yes_no(" [?] Are you agree on this? 이에 동의하십니까?"):
|
||||
if 'park' in args.checkpoints:
|
||||
download_checkpoint('park')
|
||||
if 'son' in args.checkpoints:
|
||||
download_checkpoint('son')
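# Usage: python download.py son park   (downloads and extracts the selected pre-trained checkpoints)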
|
138
eval.py
Normal file
@@ -0,0 +1,138 @@
import os
|
||||
import re
|
||||
import math
|
||||
import argparse
|
||||
from glob import glob
|
||||
|
||||
from synthesizer import Synthesizer
|
||||
from train import create_batch_inputs_from_texts
|
||||
from utils import makedirs, str2bool, backup_file
|
||||
from hparams import hparams, hparams_debug_string
|
||||
|
||||
|
||||
texts = [
|
||||
'텍스트를 음성으로 읽어주는 "음성합성" 기술은 시각 장애인을 위한 오디오북, 음성 안내 시스템, 대화 인공지능 등 많은 분야에 활용할 수 있습니다.',
|
||||
"하지만 개인이 원하는 목소리로 음성합성 엔진을 만들기에는 어려움이 많았고 소수의 기업만이 기술을 보유하고 있었습니다.",
|
||||
"최근 딥러닝 기술의 발전은 음성합성 기술의 진입 장벽을 많이 낮췄고 이제는 누구나 손쉽게 음성합성 엔진을 만들 수 있게 되었습니다.",
|
||||
|
||||
"본 세션에서는 딥러닝을 활용한 음성합성 기술을 소개하고 개발 경험과 그 과정에서 얻었던 팁을 공유하고자 합니다.",
|
||||
"음성합성 엔진을 구현하는데 사용한 세 가지 연구를 소개하고 각각의 기술이 얼마나 자연스러운 목소리를 만들 수 있는지를 공유합니다.",
|
||||
|
||||
# Harry Potter
|
||||
"그리고 헤르미온느는 겁에 질려 마룻바닥에 쓰러져 있었다.",
|
||||
"그러자 론은 요술지팡이를 꺼냈다. 무엇을 할지도 모르면서 그는 머리에 처음으로 떠오른 주문을 외치고 있었다.",
|
||||
"윙가르디움 레비오우사.... 하지만, 그렇게 소리쳤다.",
|
||||
"그러자 그 방망이가 갑자기 트롤의 손에서 벗어나, 저 위로 올라가더니 탁하며 그 주인의 머리 위에 떨어졌다.",
|
||||
"그러자 트롤이 그 자리에서 비틀거리더니 방 전체를 흔들어버릴 것 같은 커다란 소리를 내며 쿵 하고 넘어졌다. ",
|
||||
"그러자 조그맣게 펑 하는 소리가 나면서 가장 가까이 있는 가로등이 꺼졌다.",
|
||||
"그리고 그가 다시 찰깍하자 그 다음 가로등이 깜박거리며 나가 버렸다.",
|
||||
|
||||
#"그가 그렇게 가로등 끄기를 열두번 하자, 이제 그 거리에 남아 있는 불빛이라곤, ",
|
||||
#"바늘로 꼭 질러둔 것처럼 작게 보이는 멀리서 그를 지켜보고 있는 고양이의 두 눈뿐이었다.",
|
||||
#"프리벳가 4번지에 살고 있는 더즐리 부부는 자신들이 정상적이라는 것을 아주 자랑스럽게 여기는 사람들이었다. ",
|
||||
#"그들은 기이하거나 신비스런 일과는 전혀 무관해 보였다.",
|
||||
#"아니, 그런 터무니없는 것은 도저히 참아내지 못했다.",
|
||||
#"더즐리 씨는 그루닝스라는 드릴제작회사의 중역이었다.",
|
||||
#"그는 목이 거의 없을 정도로 살이 뒤룩뒤룩 찐 몸집이 큰 사내로, 코밑에는 커다란 콧수염을 기르고 있었다.",
|
||||
#"더즐리 부인은 마른 체구의 금발이었고, 목이 보통사람보다 두 배는 길어서, 담 너머로 고개를 쭉 배고 이웃 사람들을 몰래 훔쳐보는 그녀의 취미에는 더없이 제격이었다.",
|
||||
|
||||
# From Yoo Inna's Audiobook (http://campaign.happybean.naver.com/yooinna_audiobook):
|
||||
#'16세기 중엽 어느 가을날 옛 런던 시의 가난한 캔티 집안에 사내아이 하나가 태어났다.',
|
||||
#'그런데 그 집안에서는 그 사내아이를 별로 반기지 않았다.',
|
||||
#'바로 같은 날 또 한 명의 사내아이가 영국의 부유한 튜터 가문에서 태어났다.',
|
||||
#'그런데 그 가문에서는 그 아이를 무척이나 반겼다.',
|
||||
#'온 영국이 다 함께 그 아이를 반겼다.',
|
||||
|
||||
## From NAVER's Audiobook (http://campaign.happybean.naver.com/yooinna_audiobook):
|
||||
#'부랑자 패거리는 이른 새벽에 일찍 출발하여 길을 떠났다.',
|
||||
#'하늘은 찌푸렸고, 발밑의 땅은 질퍽거렸으며, 겨울의 냉기가 공기 중에 감돌았다.',
|
||||
#'지난밤의 흥겨움은 온데간데없이 사라졌다.',
|
||||
#'시무룩하게 말이 없는 사람들도 있었고, 안달복달하며 조바심을 내는 사람들도 있었지만, 기분이 좋은 사람은 하나도 없었다.',
|
||||
|
||||
## From NAVER's nVoice example (https://www.facebook.com/naverlabs/videos/422780217913446):
|
||||
#'감사합니다. Devsisters 김태훈 님의 발표였습니다.',
|
||||
#'이것으로 금일 마련된 track 2의 모든 세션이 종료되었습니다.',
|
||||
#'장시간 끝까지 참석해주신 개발자 여러분들께 진심으로 감사의 말씀을 드리며,',
|
||||
#'잠시 후 5시 15분부터 특정 주제에 관심 있는 사람들이 모여 자유롭게 이야기하는 오프미팅이 진행될 예정이므로',
|
||||
#'참여신청을 해주신 분들은 진행 요원의 안내에 따라 이동해주시기 바랍니다.',
|
||||
|
||||
## From Kakao's Son Seok hee example (https://www.youtube.com/watch?v=ScfdAH2otrY):
|
||||
#'소설가 마크 트웨인이 말했습니다.',
|
||||
#'인생에 가장 중요한 이틀이 있는데, 하나는 세상에 태어난 날이고 다른 하나는 왜 이 세상에 왔는가를 깨닫는 날이다.',
|
||||
#'그런데 그 첫번째 날은 누구나 다 알지만 두번째 날은 참 어려운 것 같습니다.',
|
||||
#'누구나 그 두번째 날을 만나기 위해 애쓰는게 삶인지도 모르겠습니다.',
|
||||
#'뉴스룸도 그런 면에서 똑같습니다.',
|
||||
#'저희들도 그 두번째의 날을 만나고 기억하기 위해 매일 매일 최선을 다하겠습니다.',
|
||||
]
|
||||
|
||||
|
||||
def get_output_base_path(load_path, eval_dirname="eval"):
|
||||
if not os.path.isdir(load_path):
|
||||
base_dir = os.path.dirname(load_path)
|
||||
else:
|
||||
base_dir = load_path
|
||||
|
||||
base_dir = os.path.join(base_dir, eval_dirname)
|
||||
if os.path.exists(base_dir):
|
||||
backup_file(base_dir)
|
||||
makedirs(base_dir)
|
||||
|
||||
m = re.compile(r'.*?\.ckpt\-([0-9]+)').match(load_path)
|
||||
base_path = os.path.join(base_dir,
|
||||
'eval-%d' % int(m.group(1)) if m else 'eval')
|
||||
return base_path
|
||||
|
||||
|
||||
def run_eval(args):
|
||||
print(hparams_debug_string())
|
||||
|
||||
load_paths = glob(args.load_path_pattern)
|
||||
|
||||
for load_path in load_paths:
|
||||
if not os.path.exists(os.path.join(load_path, "checkpoint")):
|
||||
print(" [!] Skip non model directory: {}".format(load_path))
|
||||
continue
|
||||
|
||||
synth = Synthesizer()
|
||||
synth.load(load_path)
|
||||
|
||||
for speaker_id in range(synth.num_speakers):
|
||||
base_path = get_output_base_path(load_path, "eval-{}".format(speaker_id))
|
||||
|
||||
inputs, input_lengths = create_batch_inputs_from_texts(texts)
|
||||
|
||||
for idx in range(math.ceil(len(inputs) / args.batch_size)):
|
||||
start_idx, end_idx = idx*args.batch_size, (idx+1)*args.batch_size
|
||||
|
||||
cur_texts = texts[start_idx:end_idx]
|
||||
cur_inputs = inputs[start_idx:end_idx]
|
||||
|
||||
synth.synthesize(
|
||||
texts=cur_texts,
|
||||
speaker_ids=[speaker_id] * len(cur_texts),
|
||||
tokens=cur_inputs,
|
||||
base_path="{}-{}".format(base_path, idx),
|
||||
manual_attention_mode=args.manual_attention_mode,
|
||||
base_alignment_path=args.base_alignment_path,
|
||||
)
|
||||
|
||||
synth.close()
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--batch_size', default=16, type=int)
|
||||
parser.add_argument('--load_path_pattern', required=True)
|
||||
parser.add_argument('--base_alignment_path', default=None)
|
||||
parser.add_argument('--manual_attention_mode', default=0, type=int,
|
||||
help="0: None, 1: Argmax, 2: Sharpening, 3. Pruning")
|
||||
parser.add_argument('--hparams', default='',
|
||||
help='Hyperparameter overrides as a comma-separated list of name=value pairs')
|
||||
args = parser.parse_args()
|
||||
|
||||
#hparams.max_iters = 100
|
||||
#hparams.parse(args.hparams)
|
||||
run_eval(args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
156
hparams.py
Normal file
@@ -0,0 +1,156 @@
import tensorflow as tf
|
||||
|
||||
SCALE_FACTOR = 1
|
||||
|
||||
def f(num):
|
||||
return num // SCALE_FACTOR
|
||||
|
||||
basic_params = {
|
||||
# Comma-separated list of cleaners to run on text prior to training and eval. For non-English
|
||||
# text, you may want to use "basic_cleaners" or "transliteration_cleaners" See TRAINING_DATA.md.
|
||||
'cleaners': 'korean_cleaners',
|
||||
}
|
||||
|
||||
basic_params.update({
|
||||
# Audio
|
||||
'num_mels': 80,
|
||||
'num_freq': 1025,
|
||||
'sample_rate': 20000,
|
||||
'frame_length_ms': 50,
|
||||
'frame_shift_ms': 12.5,
|
||||
'preemphasis': 0.97,
|
||||
'min_level_db': -100,
|
||||
'ref_level_db': 20,
|
||||
})
|
||||
|
||||
if True:
|
||||
basic_params.update({
|
||||
'sample_rate': 24000,
|
||||
})
|
||||
|
||||
basic_params.update({
|
||||
# Model
|
||||
'model_type': 'single', # [single, simple, deepvoice]
|
||||
'speaker_embedding_size': f(16),
|
||||
|
||||
'embedding_size': f(256),
|
||||
'dropout_prob': 0.5,
|
||||
|
||||
# Encoder
|
||||
'enc_prenet_sizes': [f(256), f(128)],
|
||||
'enc_bank_size': 16,
|
||||
'enc_bank_channel_size': f(128),
|
||||
'enc_maxpool_width': 2,
|
||||
'enc_highway_depth': 4,
|
||||
'enc_rnn_size': f(128),
|
||||
'enc_proj_sizes': [f(128), f(128)],
|
||||
'enc_proj_width': 3,
|
||||
|
||||
# Attention
|
||||
'attention_type': 'bah_mon', # ntm2-5
|
||||
'attention_size': f(256),
|
||||
'attention_state_size': f(256),
|
||||
|
||||
# Decoder recurrent network
|
||||
'dec_layer_num': 2,
|
||||
'dec_rnn_size': f(256),
|
||||
|
||||
# Decoder
|
||||
'dec_prenet_sizes': [f(256), f(128)],
|
||||
'post_bank_size': 8,
|
||||
'post_bank_channel_size': f(256),
|
||||
'post_maxpool_width': 2,
|
||||
'post_highway_depth': 4,
|
||||
'post_rnn_size': f(128),
|
||||
'post_proj_sizes': [f(256), 80], # num_mels=80
|
||||
'post_proj_width': 3,
|
||||
|
||||
'reduction_factor': 4,
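# Number of output frames the decoder predicts per step (Tacotron's "r").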
|
||||
})
|
||||
|
||||
if False: # Deep Voice 2
|
||||
basic_params.update({
|
||||
'dropout_prob': 0.8,
|
||||
|
||||
'attention_size': f(512),
|
||||
|
||||
'dec_prenet_sizes': [f(256), f(128), f(64)],
|
||||
'post_bank_channel_size': f(512),
|
||||
'post_rnn_size': f(256),
|
||||
|
||||
'reduction_factor': 4,
|
||||
})
|
||||
elif True: # Deep Voice 2
|
||||
basic_params.update({
|
||||
'dropout_prob': 0.8,
|
||||
|
||||
#'attention_size': f(512),
|
||||
|
||||
#'dec_prenet_sizes': [f(256), f(128)],
|
||||
#'post_bank_channel_size': f(512),
|
||||
'post_rnn_size': f(256),
|
||||
|
||||
'reduction_factor': 4,
|
||||
})
|
||||
elif False: # Single Speaker
|
||||
basic_params.update({
|
||||
'dropout_prob': 0.5,
|
||||
|
||||
'attention_size': f(128),
|
||||
|
||||
'post_bank_channel_size': f(128),
|
||||
#'post_rnn_size': f(128),
|
||||
|
||||
'reduction_factor': 4,
|
||||
})
|
||||
elif False: # Single Speaker with generalization
|
||||
basic_params.update({
|
||||
'dropout_prob': 0.8,
|
||||
|
||||
'attention_size': f(256),
|
||||
|
||||
'dec_prenet_sizes': [f(256), f(128), f(64)],
|
||||
'post_bank_channel_size': f(128),
|
||||
'post_rnn_size': f(128),
|
||||
|
||||
'reduction_factor': 4,
|
||||
})
|
||||
|
||||
|
||||
basic_params.update({
|
||||
# Training
|
||||
'batch_size': 16,
|
||||
'adam_beta1': 0.9,
|
||||
'adam_beta2': 0.999,
|
||||
'use_fixed_test_inputs': False,
|
||||
|
||||
'initial_learning_rate': 0.002,
|
||||
'decay_learning_rate_mode': 0,
|
||||
'initial_data_greedy': True,
|
||||
'initial_phase_step': 8000,
|
||||
'main_data_greedy_factor': 0,
|
||||
'main_data': [''],
|
||||
'prioritize_loss': False,
|
||||
|
||||
'recognition_loss_coeff': 0.2,
|
||||
'ignore_recognition_level': 1, # 0: use all, 1: ignore only unmatched_alignment, 2: fully ignore recognition
|
||||
|
||||
# Eval
|
||||
'min_tokens': 50,
|
||||
'min_iters': 30,
|
||||
'max_iters': 200,
|
||||
'skip_inadequate': False,
|
||||
|
||||
'griffin_lim_iters': 60,
|
||||
'power': 1.5, # Power to raise magnitudes to prior to Griffin-Lim
|
||||
})
|
||||
|
||||
|
||||
# Default hyperparameters:
|
||||
hparams = tf.contrib.training.HParams(**basic_params)
|
||||
|
||||
|
||||
def hparams_debug_string():
|
||||
values = hparams.values()
|
||||
hp = [' %s: %s' % (name, values[name]) for name in sorted(values)]
|
||||
return 'Hyperparameters:\n' + '\n'.join(hp)
|
17
models/__init__.py
Normal file
@@ -0,0 +1,17 @@
import os
|
||||
from glob import glob
|
||||
from .tacotron import Tacotron
|
||||
|
||||
|
||||
def create_model(hparams):
|
||||
return Tacotron(hparams)
|
||||
|
||||
|
||||
def get_most_recent_checkpoint(checkpoint_dir):
|
||||
checkpoint_paths = [path for path in glob("{}/*.ckpt-*.data-*".format(checkpoint_dir))]
|
||||
idxes = [int(os.path.basename(path).split('-')[1].split('.')[0]) for path in checkpoint_paths]
|
||||
|
||||
max_idx = max(idxes)
|
||||
latest_checkpoint = os.path.join(checkpoint_dir, "model.ckpt-{}".format(max_idx))
|
||||
print(" [*] Found lastest checkpoint: {}".format(lastest_checkpoint))
|
||||
return latest_checkpoint
|
73
models/helpers.py
Normal file
@@ -0,0 +1,73 @@
# Code based on https://github.com/keithito/tacotron/blob/master/models/tacotron.py
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.seq2seq import Helper
|
||||
|
||||
|
||||
# Adapted from tf.contrib.seq2seq.GreedyEmbeddingHelper
|
||||
class TacoTestHelper(Helper):
|
||||
def __init__(self, batch_size, output_dim, r):
|
||||
with tf.name_scope('TacoTestHelper'):
|
||||
self._batch_size = batch_size
|
||||
self._output_dim = output_dim
|
||||
self._end_token = tf.tile([0.0], [output_dim * r])
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
def initialize(self, name=None):
|
||||
return (tf.tile([False], [self._batch_size]), _go_frames(self._batch_size, self._output_dim))
|
||||
|
||||
def sample(self, time, outputs, state, name=None):
|
||||
return tf.tile([0], [self._batch_size]) # Return all 0; we ignore them
|
||||
|
||||
def next_inputs(self, time, outputs, state, sample_ids, name=None):
|
||||
'''Stop on EOS. Otherwise, pass the last output as the next input and pass through state.'''
|
||||
with tf.name_scope('TacoTestHelper'):
|
||||
finished = tf.reduce_all(tf.equal(outputs, self._end_token), axis=1)
|
||||
# Feed last output frame as next input. outputs is [N, output_dim * r]
|
||||
next_inputs = outputs[:, -self._output_dim:]
|
||||
return (finished, next_inputs, state)
|
||||
|
||||
|
||||
class TacoTrainingHelper(Helper):
|
||||
def __init__(self, inputs, targets, output_dim, r, rnn_decoder_test_mode=False):
|
||||
# inputs is [N, T_in], targets is [N, T_out, D]
|
||||
with tf.name_scope('TacoTrainingHelper'):
|
||||
self._batch_size = tf.shape(inputs)[0]
|
||||
self._output_dim = output_dim
|
||||
self._rnn_decoder_test_mode = rnn_decoder_test_mode
|
||||
|
||||
# Feed every r-th target frame as input
|
||||
self._targets = targets[:, r-1::r, :]
|
||||
|
||||
# Use full length for every target because we don't want to mask the padding frames
|
||||
num_steps = tf.shape(self._targets)[1]
|
||||
self._lengths = tf.tile([num_steps], [self._batch_size])
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
def initialize(self, name=None):
|
||||
return (tf.tile([False], [self._batch_size]), _go_frames(self._batch_size, self._output_dim))
|
||||
|
||||
def sample(self, time, outputs, state, name=None):
|
||||
return tf.tile([0], [self._batch_size]) # Return all 0; we ignore them
|
||||
|
||||
def next_inputs(self, time, outputs, state, sample_ids, name=None):
|
||||
with tf.name_scope(name or 'TacoTrainingHelper'):
|
||||
finished = (time + 1 >= self._lengths)
|
||||
if self._rnn_decoder_test_mode:
|
||||
next_inputs = outputs[:, -self._output_dim:]
|
||||
else:
|
||||
next_inputs = self._targets[:, time, :]
|
||||
return (finished, next_inputs, state)
|
||||
|
||||
|
||||
def _go_frames(batch_size, output_dim):
|
||||
'''Returns all-zero <GO> frames for a given batch size and output dimension'''
|
||||
return tf.tile([[0.0]], [batch_size, output_dim])
|
||||
|
131
models/modules.py
Normal file
@@ -0,0 +1,131 @@
# Code based on https://github.com/keithito/tacotron/blob/master/models/tacotron.py
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.rnn import GRUCell
|
||||
from tensorflow.python.layers import core
|
||||
from tensorflow.contrib.seq2seq.python.ops.attention_wrapper \
|
||||
import _bahdanau_score, _BaseAttentionMechanism, BahdanauAttention, \
|
||||
AttentionWrapper, AttentionWrapperState
|
||||
|
||||
|
||||
def get_embed(inputs, num_inputs, embed_size, name):
|
||||
embed_table = tf.get_variable(
|
||||
name, [num_inputs, embed_size], dtype=tf.float32,
|
||||
initializer=tf.truncated_normal_initializer(stddev=0.1))
|
||||
return tf.nn.embedding_lookup(embed_table, inputs)
|
||||
|
||||
|
||||
def prenet(inputs, is_training, layer_sizes, drop_prob, scope=None):
|
||||
x = inputs
|
||||
drop_rate = drop_prob if is_training else 0.0
|
||||
with tf.variable_scope(scope or 'prenet'):
|
||||
for i, size in enumerate(layer_sizes):
|
||||
dense = tf.layers.dense(x, units=size, activation=tf.nn.relu, name='dense_%d' % (i+1))
|
||||
x = tf.layers.dropout(dense, rate=drop_rate, name='dropout_%d' % (i+1))
|
||||
return x
|
||||
|
||||
def cbhg(inputs, input_lengths, is_training,
|
||||
bank_size, bank_channel_size,
|
||||
maxpool_width, highway_depth, rnn_size,
|
||||
proj_sizes, proj_width, scope,
|
||||
before_highway=None, encoder_rnn_init_state=None):
|
||||
|
||||
batch_size = tf.shape(inputs)[0]
|
||||
with tf.variable_scope(scope):
|
||||
with tf.variable_scope('conv_bank'):
|
||||
# Convolution bank: concatenate on the last axis
|
||||
# to stack channels from all convolutions
|
||||
conv_fn = lambda k: \
|
||||
conv1d(inputs, k, bank_channel_size,
|
||||
tf.nn.relu, is_training, 'conv1d_%d' % k)
|
||||
|
||||
conv_outputs = tf.concat(
|
||||
[conv_fn(k) for k in range(1, bank_size+1)], axis=-1,
|
||||
)
|
||||
|
||||
# Maxpooling:
|
||||
maxpool_output = tf.layers.max_pooling1d(
|
||||
conv_outputs,
|
||||
pool_size=maxpool_width,
|
||||
strides=1,
|
||||
padding='same')
|
||||
|
||||
# Two projection layers:
|
||||
proj_out = maxpool_output
|
||||
for idx, proj_size in enumerate(proj_sizes):
|
||||
activation_fn = None if idx == len(proj_sizes) - 1 else tf.nn.relu
|
||||
proj_out = conv1d(
|
||||
proj_out, proj_width, proj_size, activation_fn,
|
||||
is_training, 'proj_{}'.format(idx + 1))
|
||||
|
||||
# Residual connection:
|
||||
if before_highway is not None:
|
||||
expanded_before_highway = tf.expand_dims(before_highway, [1])
|
||||
tiled_before_highway = tf.tile(
|
||||
expanded_before_highway, [1, tf.shape(proj_out)[1], 1])
|
||||
|
||||
highway_input = proj_out + inputs + tiled_before_highway
|
||||
else:
|
||||
highway_input = proj_out + inputs
|
||||
|
||||
# Handle dimensionality mismatch:
|
||||
if highway_input.shape[2] != rnn_size:
|
||||
highway_input = tf.layers.dense(highway_input, rnn_size)
|
||||
|
||||
# 4-layer HighwayNet:
|
||||
for idx in range(highway_depth):
|
||||
highway_input = highwaynet(highway_input, 'highway_%d' % (idx+1))
|
||||
|
||||
rnn_input = highway_input
|
||||
|
||||
# Bidirectional RNN
|
||||
if encoder_rnn_init_state is not None:
|
||||
initial_state_fw, initial_state_bw = \
|
||||
tf.split(encoder_rnn_init_state, 2, 1)
|
||||
else:
|
||||
initial_state_fw, initial_state_bw = None, None
|
||||
|
||||
cell_fw, cell_bw = GRUCell(rnn_size), GRUCell(rnn_size)
|
||||
outputs, states = tf.nn.bidirectional_dynamic_rnn(
|
||||
cell_fw, cell_bw,
|
||||
rnn_input,
|
||||
sequence_length=input_lengths,
|
||||
initial_state_fw=initial_state_fw,
|
||||
initial_state_bw=initial_state_bw,
|
||||
dtype=tf.float32)
|
||||
return tf.concat(outputs, axis=2) # Concat forward and backward
|
||||
|
||||
|
||||
def batch_tile(tensor, batch_size):
|
||||
expanded_tensor = tf.expand_dims(tensor, [0])
|
||||
return tf.tile(expanded_tensor, \
|
||||
[batch_size] + [1 for _ in tensor.get_shape()])
|
||||
|
||||
|
||||
def highwaynet(inputs, scope):
|
||||
highway_dim = int(inputs.get_shape()[-1])
|
||||
|
||||
with tf.variable_scope(scope):
|
||||
H = tf.layers.dense(
|
||||
inputs,
|
||||
units=highway_dim,
|
||||
activation=tf.nn.relu,
|
||||
name='H')
|
||||
T = tf.layers.dense(
|
||||
inputs,
|
||||
units=highway_dim,
|
||||
activation=tf.nn.sigmoid,
|
||||
name='T',
|
||||
bias_initializer=tf.constant_initializer(-1.0))
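# The -1.0 bias starts the transform gate T near zero, so the highway layer initially passes inputs through.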
|
||||
return H * T + inputs * (1.0 - T)
|
||||
|
||||
|
||||
def conv1d(inputs, kernel_size, channels, activation, is_training, scope):
|
||||
with tf.variable_scope(scope):
|
||||
conv1d_output = tf.layers.conv1d(
|
||||
inputs,
|
||||
filters=channels,
|
||||
kernel_size=kernel_size,
|
||||
activation=activation,
|
||||
padding='same')
|
||||
return tf.layers.batch_normalization(conv1d_output, training=is_training)
|
418
models/rnn_wrappers.py
Normal file
@@ -0,0 +1,418 @@
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.rnn import RNNCell
|
||||
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.layers import core as layers_core
|
||||
from tensorflow.contrib.data.python.util import nest
|
||||
from tensorflow.contrib.seq2seq.python.ops.attention_wrapper \
|
||||
import _bahdanau_score, _BaseAttentionMechanism, BahdanauAttention, \
|
||||
AttentionWrapperState, AttentionMechanism
|
||||
|
||||
from .modules import prenet
|
||||
|
||||
_zero_state_tensors = rnn_cell_impl._zero_state_tensors
|
||||
|
||||
|
||||
|
||||
class AttentionWrapper(RNNCell):
|
||||
"""Wraps another `RNNCell` with attention.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
cell,
|
||||
attention_mechanism,
|
||||
is_manual_attention,
|
||||
manual_alignments,
|
||||
attention_layer_size=None,
|
||||
alignment_history=False,
|
||||
cell_input_fn=None,
|
||||
output_attention=True,
|
||||
initial_cell_state=None,
|
||||
name=None):
|
||||
"""Construct the `AttentionWrapper`.
|
||||
Args:
|
||||
cell: An instance of `RNNCell`.
|
||||
attention_mechanism: A list of `AttentionMechanism` instances or a single
|
||||
instance.
|
||||
attention_layer_size: A list of Python integers or a single Python
|
||||
integer, the depth of the attention (output) layer(s). If None
|
||||
(default), use the context as attention at each time step. Otherwise,
|
||||
feed the context and cell output into the attention layer to generate
|
||||
attention at each time step. If attention_mechanism is a list,
|
||||
attention_layer_size must be a list of the same length.
|
||||
alignment_history: Python boolean, whether to store alignment history
|
||||
from all time steps in the final output state (currently stored as a
|
||||
time major `TensorArray` on which you must call `stack()`).
|
||||
cell_input_fn: (optional) A `callable`. The default is:
|
||||
`lambda inputs, attention: tf.concat([inputs, attention], -1)`.
|
||||
output_attention: Python bool. If `True` (default), the output at each
|
||||
time step is the attention value. This is the behavior of Luong-style
|
||||
attention mechanisms. If `False`, the output at each time step is
|
||||
the output of `cell`. This is the behavior of Bahdanau-style
|
||||
attention mechanisms. In both cases, the `attention` tensor is
|
||||
propagated to the next time step via the state and is used there.
|
||||
This flag only controls whether the attention mechanism is propagated
|
||||
up to the next cell in an RNN stack or to the top RNN output.
|
||||
initial_cell_state: The initial state value to use for the cell when
|
||||
the user calls `zero_state()`. Note that if this value is provided
|
||||
now, and the user uses a `batch_size` argument of `zero_state` which
|
||||
does not match the batch size of `initial_cell_state`, proper
|
||||
behavior is not guaranteed.
|
||||
name: Name to use when creating ops.
|
||||
Raises:
|
||||
TypeError: `attention_layer_size` is not None and (`attention_mechanism`
|
||||
is a list but `attention_layer_size` is not; or vice versa).
|
||||
ValueError: if `attention_layer_size` is not None, `attention_mechanism`
|
||||
is a list, and its length does not match that of `attention_layer_size`.
|
||||
"""
|
||||
super(AttentionWrapper, self).__init__(name=name)
|
||||
|
||||
self.is_manual_attention = is_manual_attention
|
||||
self.manual_alignments = manual_alignments
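# When is_manual_attention is true, the wrapper is expected to use these precomputed
# manual_alignments instead of the learned attention weights (assumption: the step logic
# that consumes them lies outside this excerpt; eval.py exposes it as manual_attention_mode).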
|
||||
|
||||
if isinstance(attention_mechanism, (list, tuple)):
|
||||
self._is_multi = True
|
||||
attention_mechanisms = attention_mechanism
|
||||
for attention_mechanism in attention_mechanisms:
|
||||
if not isinstance(attention_mechanism, AttentionMechanism):
|
||||
raise TypeError(
|
||||
"attention_mechanism must contain only instances of "
|
||||
"AttentionMechanism, saw type: %s"
|
||||
% type(attention_mechanism).__name__)
|
||||
else:
|
||||
self._is_multi = False
|
||||
if not isinstance(attention_mechanism, AttentionMechanism):
|
||||
raise TypeError(
|
||||
"attention_mechanism must be an AttentionMechanism or list of "
|
||||
"multiple AttentionMechanism instances, saw type: %s"
|
||||
% type(attention_mechanism).__name__)
|
||||
attention_mechanisms = (attention_mechanism,)
|
||||
|
||||
if cell_input_fn is None:
|
||||
cell_input_fn = (
|
||||
lambda inputs, attention: tf.concat([inputs, attention], -1))
|
||||
else:
|
||||
if not callable(cell_input_fn):
|
||||
raise TypeError(
|
||||
"cell_input_fn must be callable, saw type: %s"
|
||||
% type(cell_input_fn).__name__)
|
||||
|
||||
if attention_layer_size is not None:
|
||||
attention_layer_sizes = tuple(
|
||||
attention_layer_size
|
||||
if isinstance(attention_layer_size, (list, tuple))
|
||||
else (attention_layer_size,))
|
||||
if len(attention_layer_sizes) != len(attention_mechanisms):
|
||||
raise ValueError(
|
||||
"If provided, attention_layer_size must contain exactly one "
|
||||
"integer per attention_mechanism, saw: %d vs %d"
|
||||
% (len(attention_layer_sizes), len(attention_mechanisms)))
|
||||
self._attention_layers = tuple(
|
||||
layers_core.Dense(
|
||||
attention_layer_size, name="attention_layer", use_bias=False)
|
||||
for attention_layer_size in attention_layer_sizes)
|
||||
self._attention_layer_size = sum(attention_layer_sizes)
|
||||
else:
|
||||
self._attention_layers = None
|
||||
self._attention_layer_size = sum(
|
||||
attention_mechanism.values.get_shape()[-1].value
|
||||
for attention_mechanism in attention_mechanisms)
|
||||
|
||||
self._cell = cell
|
||||
self._attention_mechanisms = attention_mechanisms
|
||||
self._cell_input_fn = cell_input_fn
|
||||
self._output_attention = output_attention
|
||||
self._alignment_history = alignment_history
|
||||
with tf.name_scope(name, "AttentionWrapperInit"):
|
||||
if initial_cell_state is None:
|
||||
self._initial_cell_state = None
|
||||
else:
|
||||
final_state_tensor = nest.flatten(initial_cell_state)[-1]
|
||||
state_batch_size = (
|
||||
final_state_tensor.shape[0].value
|
||||
or tf.shape(final_state_tensor)[0])
|
||||
error_message = (
|
||||
"When constructing AttentionWrapper %s: " % self._base_name +
|
||||
"Non-matching batch sizes between the memory "
|
||||
"(encoder output) and initial_cell_state. Are you using "
|
||||
"the BeamSearchDecoder? You may need to tile your initial state "
|
||||
"via the tf.contrib.seq2seq.tile_batch function with argument "
|
||||
"multiple=beam_width.")
|
||||
with tf.control_dependencies(
|
||||
self._batch_size_checks(state_batch_size, error_message)):
|
||||
self._initial_cell_state = nest.map_structure(
|
||||
lambda s: tf.identity(s, name="check_initial_cell_state"),
|
||||
initial_cell_state)
|
||||
|
||||
def _batch_size_checks(self, batch_size, error_message):
|
||||
return [tf.assert_equal(batch_size,
|
||||
attention_mechanism.batch_size,
|
||||
message=error_message)
|
||||
for attention_mechanism in self._attention_mechanisms]
|
||||
|
||||
def _item_or_tuple(self, seq):
|
||||
"""Returns `seq` as tuple or the singular element.
|
||||
Which one is returned is determined by how the AttentionMechanism(s) were passed
|
||||
to the constructor.
|
||||
Args:
|
||||
seq: A non-empty sequence of items or generator.
|
||||
Returns:
|
||||
Either the values in the sequence as a tuple if AttentionMechanism(s)
|
||||
were passed to the constructor as a sequence or the singular element.
|
||||
"""
|
||||
t = tuple(seq)
|
||||
if self._is_multi:
|
||||
return t
|
||||
else:
|
||||
return t[0]
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
if self._output_attention:
|
||||
return self._attention_layer_size
|
||||
else:
|
||||
return self._cell.output_size
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return AttentionWrapperState(
|
||||
cell_state=self._cell.state_size,
|
||||
time=tf.TensorShape([]),
|
||||
attention=self._attention_layer_size,
|
||||
alignments=self._item_or_tuple(
|
||||
a.alignments_size for a in self._attention_mechanisms),
|
||||
alignment_history=self._item_or_tuple(
|
||||
() for _ in self._attention_mechanisms)) # sometimes a TensorArray
|
||||
|
||||
def zero_state(self, batch_size, dtype):
|
||||
with tf.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
|
||||
if self._initial_cell_state is not None:
|
||||
cell_state = self._initial_cell_state
|
||||
else:
|
||||
cell_state = self._cell.zero_state(batch_size, dtype)
|
||||
error_message = (
|
||||
"When calling zero_state of AttentionWrapper %s: " % self._base_name +
|
||||
"Non-matching batch sizes between the memory "
|
||||
"(encoder output) and the requested batch size. Are you using "
|
||||
"the BeamSearchDecoder? If so, make sure your encoder output has "
|
||||
"been tiled to beam_width via tf.contrib.seq2seq.tile_batch, and "
|
||||
"the batch_size= argument passed to zero_state is "
|
||||
"batch_size * beam_width.")
|
||||
with tf.control_dependencies(
|
||||
self._batch_size_checks(batch_size, error_message)):
|
||||
cell_state = nest.map_structure(
|
||||
lambda s: tf.identity(s, name="checked_cell_state"),
|
||||
cell_state)
|
||||
|
||||
return AttentionWrapperState(
|
||||
cell_state=cell_state,
|
||||
time=tf.zeros([], dtype=tf.int32),
|
||||
attention=_zero_state_tensors(self._attention_layer_size, batch_size, dtype),
|
||||
alignments=self._item_or_tuple(
|
||||
attention_mechanism.initial_alignments(batch_size, dtype)
|
||||
for attention_mechanism in self._attention_mechanisms),
|
||||
alignment_history=self._item_or_tuple(
|
||||
tf.TensorArray(dtype=dtype, size=0, dynamic_size=True)
|
||||
if self._alignment_history else ()
|
||||
for _ in self._attention_mechanisms))
|
||||
|
||||
def call(self, inputs, state):
|
||||
"""Perform a step of attention-wrapped RNN.
|
||||
- Step 1: Mix the `inputs` and previous step's `attention` output via
|
||||
`cell_input_fn`.
|
||||
- Step 2: Call the wrapped `cell` with this input and its previous state.
|
||||
- Step 3: Score the cell's output with `attention_mechanism`.
|
||||
- Step 4: Calculate the alignments by passing the score through the
|
||||
`normalizer`.
|
||||
- Step 5: Calculate the context vector as the inner product between the
|
||||
alignments and the attention_mechanism's values (memory).
|
||||
- Step 6: Calculate the attention output by concatenating the cell output
|
||||
and context through the attention layer (a linear layer with
|
||||
`attention_layer_size` outputs).
|
||||
Args:
|
||||
inputs: (Possibly nested tuple of) Tensor, the input at this time step.
|
||||
state: An instance of `AttentionWrapperState` containing
|
||||
tensors from the previous time step.
|
||||
Returns:
|
||||
A tuple `(attention_or_cell_output, next_state)`, where:
|
||||
- `attention_or_cell_output` depending on `output_attention`.
|
||||
- `next_state` is an instance of `AttentionWrapperState`
|
||||
containing the state calculated at this time step.
|
||||
Raises:
|
||||
TypeError: If `state` is not an instance of `AttentionWrapperState`.
|
||||
"""
|
||||
if not isinstance(state, AttentionWrapperState):
|
||||
raise TypeError("Expected state to be instance of AttentionWrapperState. "
|
||||
"Received type %s instead." % type(state))
|
||||
|
||||
# Step 1: Calculate the true inputs to the cell based on the
|
||||
# previous attention value.
|
||||
cell_inputs = self._cell_input_fn(inputs, state.attention)
|
||||
cell_state = state.cell_state
|
||||
cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
|
||||
|
||||
cell_batch_size = (
|
||||
cell_output.shape[0].value or tf.shape(cell_output)[0])
|
||||
error_message = (
|
||||
"When applying AttentionWrapper %s: " % self.name +
|
||||
"Non-matching batch sizes between the memory "
|
||||
"(encoder output) and the query (decoder output). Are you using "
|
||||
"the BeamSearchDecoder? You may need to tile your memory input via "
|
||||
"the tf.contrib.seq2seq.tile_batch function with argument "
|
||||
"multiple=beam_width.")
|
||||
with tf.control_dependencies(
|
||||
self._batch_size_checks(cell_batch_size, error_message)):
|
||||
cell_output = tf.identity(
|
||||
cell_output, name="checked_cell_output")
|
||||
|
||||
if self._is_multi:
|
||||
previous_alignments = state.alignments
|
||||
previous_alignment_history = state.alignment_history
|
||||
else:
|
||||
previous_alignments = [state.alignments]
|
||||
previous_alignment_history = [state.alignment_history]
|
||||
|
||||
all_alignments = []
|
||||
all_attentions = []
|
||||
all_histories = []
|
||||
|
||||
for i, attention_mechanism in enumerate(self._attention_mechanisms):
|
||||
attention, alignments = _compute_attention(
|
||||
attention_mechanism, cell_output, previous_alignments[i],
|
||||
self._attention_layers[i] if self._attention_layers else None,
|
||||
self.is_manual_attention, self.manual_alignments, state.time)
|
||||
|
||||
alignment_history = previous_alignment_history[i].write(
|
||||
state.time, alignments) if self._alignment_history else ()
|
||||
|
||||
all_alignments.append(alignments)
|
||||
all_histories.append(alignment_history)
|
||||
all_attentions.append(attention)
|
||||
|
||||
attention = tf.concat(all_attentions, 1)
|
||||
next_state = AttentionWrapperState(
|
||||
time=state.time + 1,
|
||||
cell_state=next_cell_state,
|
||||
attention=attention,
|
||||
alignments=self._item_or_tuple(all_alignments),
|
||||
alignment_history=self._item_or_tuple(all_histories))
|
||||
|
||||
if self._output_attention:
|
||||
return attention, next_state
|
||||
else:
|
||||
return cell_output, next_state
|
||||
|
||||
def _compute_attention(
|
||||
attention_mechanism, cell_output, previous_alignments,
|
||||
attention_layer, is_manual_attention, manual_alignments, time):
|
||||
|
||||
computed_alignments = attention_mechanism(
|
||||
cell_output, previous_alignments=previous_alignments)
|
||||
batch_size, max_time = \
|
||||
tf.shape(computed_alignments)[0], tf.shape(computed_alignments)[1]
|
||||
|
||||
alignments = tf.cond(
|
||||
is_manual_attention,
|
||||
lambda: manual_alignments[:, time, :],
|
||||
lambda: computed_alignments,
|
||||
)
|
||||
|
||||
#alignments = tf.one_hot(tf.zeros((batch_size,), dtype=tf.int32), max_time, dtype=tf.float32)
|
||||
|
||||
# Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
|
||||
expanded_alignments = tf.expand_dims(alignments, 1)
|
||||
|
||||
# Context is the inner product of alignments and values along the
|
||||
# memory time dimension.
|
||||
# alignments shape is
|
||||
# [batch_size, 1, memory_time]
|
||||
# attention_mechanism.values shape is
|
||||
# [batch_size, memory_time, memory_size]
|
||||
# the batched matmul is over memory_time, so the output shape is
|
||||
# [batch_size, 1, memory_size].
|
||||
# we then squeeze out the singleton dim.
|
||||
context = tf.matmul(expanded_alignments, attention_mechanism.values)
|
||||
context = tf.squeeze(context, [1])
|
||||
|
||||
if attention_layer is not None:
|
||||
attention = attention_layer(tf.concat([cell_output, context], 1))
|
||||
else:
|
||||
attention = context
|
||||
|
||||
return attention, alignments
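For intuition, the context computation above can be checked with a small NumPy sketch (shapes only; the batch size, memory length, and memory depth are made-up illustrative numbers, not values from this repository):

import numpy as np

batch_size, memory_time, memory_size = 2, 7, 5        # illustrative sizes only

alignments = np.random.rand(batch_size, memory_time)
alignments /= alignments.sum(axis=1, keepdims=True)   # rows sum to 1, like softmax output
values = np.random.rand(batch_size, memory_time, memory_size)

expanded = alignments[:, np.newaxis, :]                # [batch_size, 1, memory_time]
context = np.matmul(expanded, values).squeeze(1)       # [batch_size, memory_size]

assert context.shape == (batch_size, memory_size)
# context[b] is the alignment-weighted average of values[b] over memory_time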
|
||||
|
||||
|
||||
class DecoderPrenetWrapper(RNNCell):
|
||||
'''Runs RNN inputs through a prenet before sending them to the cell.'''
|
||||
def __init__(
|
||||
self, cell, embed_to_concat,
|
||||
is_training, prenet_sizes, dropout_prob):
|
||||
|
||||
super(DecoderPrenetWrapper, self).__init__()
|
||||
self._is_training = is_training
|
||||
|
||||
self._cell = cell
|
||||
self._embed_to_concat = embed_to_concat
|
||||
|
||||
self.prenet_sizes = prenet_sizes
|
||||
self.dropout_prob = dropout_prob
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return self._cell.state_size
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self._cell.output_size
|
||||
|
||||
def call(self, inputs, state):
|
||||
prenet_out = prenet(
|
||||
inputs, self._is_training,
|
||||
self.prenet_sizes, self.dropout_prob, scope='decoder_prenet')
|
||||
|
||||
if self._embed_to_concat is not None:
|
||||
concat_out = tf.concat(
|
||||
[prenet_out, self._embed_to_concat],
|
||||
axis=-1, name='speaker_concat')
|
||||
return self._cell(concat_out, state)
|
||||
else:
|
||||
return self._cell(prenet_out, state)
|
||||
|
||||
def zero_state(self, batch_size, dtype):
|
||||
return self._cell.zero_state(batch_size, dtype)
|
||||
|
||||
|
||||
|
||||
class ConcatOutputAndAttentionWrapper(RNNCell):
|
||||
'''Concatenates RNN cell output with the attention context vector.
|
||||
|
||||
This is expected to wrap a cell wrapped with an AttentionWrapper constructed with
|
||||
attention_layer_size=None and output_attention=False. Such a cell's state will include an
|
||||
"attention" field that is the context vector.
|
||||
'''
|
||||
def __init__(self, cell, embed_to_concat):
|
||||
super(ConcatOutputAndAttentionWrapper, self).__init__()
|
||||
self._cell = cell
|
||||
self._embed_to_concat = embed_to_concat
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return self._cell.state_size
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self._cell.output_size + self._cell.state_size.attention
|
||||
|
||||
def call(self, inputs, state):
|
||||
output, res_state = self._cell(inputs, state)
|
||||
|
||||
if self._embed_to_concat is not None:
|
||||
tensors = [
|
||||
output, res_state.attention,
|
||||
self._embed_to_concat,
|
||||
]
|
||||
return tf.concat(tensors, axis=-1), res_state
|
||||
else:
|
||||
return tf.concat([output, res_state.attention], axis=-1), res_state
|
||||
|
||||
def zero_state(self, batch_size, dtype):
|
||||
return self._cell.zero_state(batch_size, dtype)
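The three wrappers in this file are designed to be stacked. A minimal sketch of the composition, mirroring how models/tacotron.py below wires them (the sizes, placeholders, and the BahdanauAttention choice here are illustrative assumptions, not fixed by this file):

import tensorflow as tf
from tensorflow.contrib.rnn import GRUCell
from tensorflow.contrib.seq2seq import BahdanauAttention

encoder_outputs = tf.placeholder(tf.float32, [None, None, 256])   # [N, T_in, enc_dim]
is_manual_attention = tf.placeholder(tf.bool, shape=())
manual_alignments = tf.placeholder(tf.float32, [None, None, None])

attention_mechanism = BahdanauAttention(256, encoder_outputs)

# Prenet in front of the attention RNN cell.
prenet_cell = DecoderPrenetWrapper(GRUCell(256), None, True, [256, 128], 0.5)

# Custom AttentionWrapper defined above: keeps the context in the state and
# can be overridden by manual_alignments when is_manual_attention is True.
attention_cell = AttentionWrapper(
    prenet_cell, attention_mechanism,
    is_manual_attention, manual_alignments,
    alignment_history=True, output_attention=False)

# Output becomes [cell_output ; attention context], fed to the decoder RNN stack.
concat_cell = ConcatOutputAndAttentionWrapper(attention_cell, embed_to_concat=None)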
|
343
models/tacotron.py
Normal file
|
@ -0,0 +1,343 @@
|
|||
# Code based on https://github.com/keithito/tacotron/blob/master/models/tacotron.py
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.seq2seq import BasicDecoder, BahdanauAttention, BahdanauMonotonicAttention, LuongAttention
|
||||
from tensorflow.contrib.rnn import GRUCell, MultiRNNCell, OutputProjectionWrapper, ResidualWrapper
|
||||
|
||||
from utils.infolog import log
|
||||
from text.symbols import symbols
|
||||
|
||||
from .modules import *
|
||||
from .helpers import TacoTestHelper, TacoTrainingHelper
|
||||
from .rnn_wrappers import AttentionWrapper, DecoderPrenetWrapper, ConcatOutputAndAttentionWrapper
|
||||
|
||||
|
||||
class Tacotron():
|
||||
def __init__(self, hparams):
|
||||
self._hparams = hparams
|
||||
|
||||
|
||||
def initialize(
|
||||
self, inputs, input_lengths, num_speakers, speaker_id,
|
||||
mel_targets=None, linear_targets=None, loss_coeff=None,
|
||||
rnn_decoder_test_mode=False, is_randomly_initialized=False,
|
||||
):
|
||||
is_training = linear_targets is not None
|
||||
self.is_randomly_initialized = is_randomly_initialized
|
||||
|
||||
with tf.variable_scope('inference') as scope:
|
||||
hp = self._hparams
|
||||
batch_size = tf.shape(inputs)[0]
|
||||
|
||||
# Embeddings
|
||||
char_embed_table = tf.get_variable(
|
||||
'embedding', [len(symbols), hp.embedding_size], dtype=tf.float32,
|
||||
initializer=tf.truncated_normal_initializer(stddev=0.5))
|
||||
# [N, T_in, embedding_size]
|
||||
char_embedded_inputs = \
|
||||
tf.nn.embedding_lookup(char_embed_table, inputs)
|
||||
|
||||
self.num_speakers = num_speakers
|
||||
if self.num_speakers > 1:
|
||||
if hp.speaker_embedding_size != 1:
|
||||
speaker_embed_table = tf.get_variable(
|
||||
'speaker_embedding',
|
||||
[self.num_speakers, hp.speaker_embedding_size], dtype=tf.float32,
|
||||
initializer=tf.truncated_normal_initializer(stddev=0.5))
|
||||
# [N, T_in, speaker_embedding_size]
|
||||
speaker_embed = tf.nn.embedding_lookup(speaker_embed_table, speaker_id)
|
||||
|
||||
if hp.model_type == 'deepvoice':
|
||||
if hp.speaker_embedding_size == 1:
|
||||
before_highway = get_embed(
|
||||
speaker_id, self.num_speakers,
|
||||
hp.enc_prenet_sizes[-1], "before_highway")
|
||||
encoder_rnn_init_state = get_embed(
|
||||
speaker_id, self.num_speakers,
|
||||
hp.enc_rnn_size * 2, "encoder_rnn_init_state")
|
||||
|
||||
attention_rnn_init_state = get_embed(
|
||||
speaker_id, self.num_speakers,
|
||||
hp.attention_state_size, "attention_rnn_init_state")
|
||||
decoder_rnn_init_states = [get_embed(
|
||||
speaker_id, self.num_speakers,
|
||||
hp.dec_rnn_size, "decoder_rnn_init_states{}".format(idx + 1)) \
|
||||
for idx in range(hp.dec_layer_num)]
|
||||
else:
|
||||
deep_dense = lambda x, dim: \
|
||||
tf.layers.dense(x, dim, activation=tf.nn.softsign)
|
||||
|
||||
before_highway = deep_dense(
|
||||
speaker_embed, hp.enc_prenet_sizes[-1])
|
||||
encoder_rnn_init_state = deep_dense(
|
||||
speaker_embed, hp.enc_rnn_size * 2)
|
||||
|
||||
attention_rnn_init_state = deep_dense(
|
||||
speaker_embed, hp.attention_state_size)
|
||||
decoder_rnn_init_states = [deep_dense(
|
||||
speaker_embed, hp.dec_rnn_size) for _ in range(hp.dec_layer_num)]
|
||||
|
||||
speaker_embed = None # deepvoice does not use speaker_embed directly
|
||||
elif hp.model_type == 'simple':
|
||||
before_highway = None
|
||||
encoder_rnn_init_state = None
|
||||
attention_rnn_init_state = None
|
||||
decoder_rnn_init_states = None
|
||||
else:
|
||||
raise Exception(" [!] Unkown multi-speaker model type: {}".format(hp.model_type))
|
||||
else:
|
||||
speaker_embed = None
|
||||
before_highway = None
|
||||
encoder_rnn_init_state = None
|
||||
attention_rnn_init_state = None
|
||||
decoder_rnn_init_states = None
|
||||
|
||||
##############
|
||||
# Encoder
|
||||
##############
|
||||
|
||||
# [N, T_in, enc_prenet_sizes[-1]]
|
||||
prenet_outputs = prenet(char_embedded_inputs, is_training,
|
||||
hp.enc_prenet_sizes, hp.dropout_prob,
|
||||
scope='prenet')
|
||||
|
||||
encoder_outputs = cbhg(
|
||||
prenet_outputs, input_lengths, is_training,
|
||||
hp.enc_bank_size, hp.enc_bank_channel_size,
|
||||
hp.enc_maxpool_width, hp.enc_highway_depth, hp.enc_rnn_size,
|
||||
hp.enc_proj_sizes, hp.enc_proj_width,
|
||||
scope="encoder_cbhg",
|
||||
before_highway=before_highway,
|
||||
encoder_rnn_init_state=encoder_rnn_init_state)
|
||||
|
||||
|
||||
##############
|
||||
# Attention
|
||||
##############
|
||||
|
||||
# For manual control of attention
|
||||
self.is_manual_attention = tf.placeholder(
|
||||
tf.bool, shape=(), name='is_manual_attention',
|
||||
)
|
||||
self.manual_alignments = tf.placeholder(
|
||||
tf.float32, shape=[None, None, None], name="manual_alignments",
|
||||
)
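These two placeholders are always fed at run time: during ordinary training and synthesis the switch is off and a dummy alignment array is passed (see get_dummy_feed_dict below), while for forced attention a precomputed [N, T_out, T_in] alignment matrix is supplied. A sketch of the two modes (model and precomputed_alignments are stand-in names, following how synthesizer.py drives the graph):

import numpy as np

# Normal decoding: alignments are computed by the attention mechanism.
feed_dict = {
    model.is_manual_attention: False,
    model.manual_alignments: np.zeros([1, 1, 1]),   # ignored, but must be fed
}

# Forced attention: alignments of shape [N, T_out, T_in] override the mechanism,
# one row per decoder step, selected by state.time inside _compute_attention.
feed_dict = {
    model.is_manual_attention: True,
    model.manual_alignments: precomputed_alignments,
}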
|
||||
|
||||
dec_prenet_outputs = DecoderPrenetWrapper(
|
||||
GRUCell(hp.attention_state_size),
|
||||
speaker_embed,
|
||||
is_training, hp.dec_prenet_sizes, hp.dropout_prob)
|
||||
|
||||
if hp.attention_type == 'bah_mon':
|
||||
attention_mechanism = BahdanauMonotonicAttention(
|
||||
hp.attention_size, encoder_outputs)
|
||||
elif hp.attention_type == 'bah_norm':
|
||||
attention_mechanism = BahdanauAttention(
|
||||
hp.attention_size, encoder_outputs, normalize=True)
|
||||
elif hp.attention_type == 'luong_scaled':
|
||||
attention_mechanism = LuongAttention(
|
||||
hp.attention_size, encoder_outputs, scale=True)
|
||||
elif hp.attention_type == 'luong':
|
||||
attention_mechanism = LuongAttention(
|
||||
hp.attention_size, encoder_outputs)
|
||||
elif hp.attention_type == 'bah':
|
||||
attention_mechanism = BahdanauAttention(
|
||||
hp.attention_size, encoder_outputs)
|
||||
elif hp.attention_type.startswith('ntm2'):
|
||||
shift_width = int(hp.attention_type.split('-')[-1])
|
||||
attention_mechanism = NTMAttention2(
|
||||
hp.attention_size, encoder_outputs, shift_width=shift_width)
|
||||
else:
|
||||
raise Exception(" [!] Unkown attention type: {}".format(hp.attention_type))
|
||||
|
||||
attention_cell = AttentionWrapper(
|
||||
dec_prenet_outputs,
|
||||
attention_mechanism,
|
||||
self.is_manual_attention,
|
||||
self.manual_alignments,
|
||||
initial_cell_state=attention_rnn_init_state,
|
||||
alignment_history=True,
|
||||
output_attention=False
|
||||
)
|
||||
|
||||
# Concatenate attention context vector and RNN cell output into a 512D vector.
|
||||
# [N, T_in, attention_size+attention_state_size]
|
||||
concat_cell = ConcatOutputAndAttentionWrapper(
|
||||
attention_cell, embed_to_concat=speaker_embed)
|
||||
|
||||
# Decoder (layers specified bottom to top):
|
||||
cells = [OutputProjectionWrapper(concat_cell, hp.dec_rnn_size)]
|
||||
for _ in range(hp.dec_layer_num):
|
||||
cells.append(ResidualWrapper(GRUCell(hp.dec_rnn_size)))
|
||||
|
||||
# [N, T_in, 256]
|
||||
decoder_cell = MultiRNNCell(cells, state_is_tuple=True)
|
||||
|
||||
# Project onto r mel spectrograms (predict r outputs at each RNN step):
|
||||
output_cell = OutputProjectionWrapper(
|
||||
decoder_cell, hp.num_mels * hp.reduction_factor)
|
||||
decoder_init_state = output_cell.zero_state(
|
||||
batch_size=batch_size, dtype=tf.float32)
|
||||
|
||||
if hp.model_type == "deepvoice":
|
||||
# decoder_init_state[0] : AttentionWrapperState
|
||||
# = cell_state + attention + time + alignments + alignment_history
|
||||
# decoder_init_state[0][0] = attention_rnn_init_state (already applied)
|
||||
decoder_init_state = list(decoder_init_state)
|
||||
|
||||
for idx, cell in enumerate(decoder_rnn_init_states):
|
||||
shape1 = decoder_init_state[idx + 1].get_shape().as_list()
|
||||
shape2 = cell.get_shape().as_list()
|
||||
if shape1 != shape2:
|
||||
raise Exception(" [!] Shape {} and {} should be equal". \
|
||||
format(shape1, shape2))
|
||||
decoder_init_state[idx + 1] = cell
|
||||
|
||||
decoder_init_state = tuple(decoder_init_state)
|
||||
|
||||
if is_training:
|
||||
helper = TacoTrainingHelper(
|
||||
inputs, mel_targets, hp.num_mels, hp.reduction_factor,
|
||||
rnn_decoder_test_mode)
|
||||
else:
|
||||
helper = TacoTestHelper(
|
||||
batch_size, hp.num_mels, hp.reduction_factor)
|
||||
|
||||
(decoder_outputs, _), final_decoder_state, _ = \
|
||||
tf.contrib.seq2seq.dynamic_decode(
|
||||
BasicDecoder(output_cell, helper, decoder_init_state),
|
||||
maximum_iterations=hp.max_iters)
|
||||
|
||||
# [N, T_out, M]
|
||||
mel_outputs = tf.reshape(
|
||||
decoder_outputs, [batch_size, -1, hp.num_mels])
|
||||
|
||||
# Add post-processing CBHG:
|
||||
# [N, T_out, 256]
|
||||
#post_outputs = post_cbhg(mel_outputs, hp.num_mels, is_training)
|
||||
post_outputs = cbhg(
|
||||
mel_outputs, None, is_training,
|
||||
hp.post_bank_size, hp.post_bank_channel_size,
|
||||
hp.post_maxpool_width, hp.post_highway_depth, hp.post_rnn_size,
|
||||
hp.post_proj_sizes, hp.post_proj_width,
|
||||
scope='post_cbhg')
|
||||
|
||||
if speaker_embed is not None and hp.model_type == 'simple':
|
||||
expanded_speaker_emb = tf.expand_dims(speaker_embed, [1])
|
||||
tiled_speaker_embedding = tf.tile(
|
||||
expanded_speaker_emb, [1, tf.shape(post_outputs)[1], 1])
|
||||
|
||||
# [N, T_out, 256 + alpha]
|
||||
post_outputs = \
|
||||
tf.concat([tiled_speaker_embedding, post_outputs], axis=-1)
|
||||
|
||||
linear_outputs = tf.layers.dense(post_outputs, hp.num_freq) # [N, T_out, F]
|
||||
|
||||
# Grab alignments from the final decoder state:
|
||||
alignments = tf.transpose(
|
||||
final_decoder_state[0].alignment_history.stack(), [1, 2, 0])
|
||||
|
||||
|
||||
self.inputs = inputs
|
||||
self.speaker_id = speaker_id
|
||||
self.input_lengths = input_lengths
|
||||
self.loss_coeff = loss_coeff
|
||||
self.mel_outputs = mel_outputs
|
||||
self.linear_outputs = linear_outputs
|
||||
self.alignments = alignments
|
||||
self.mel_targets = mel_targets
|
||||
self.linear_targets = linear_targets
|
||||
self.final_decoder_state = final_decoder_state
|
||||
|
||||
log('='*40)
|
||||
log(' model_type: %s' % hp.model_type)
|
||||
log('='*40)
|
||||
|
||||
log('Initialized Tacotron model. Dimensions: ')
|
||||
log(' embedding: %d' % char_embedded_inputs.shape[-1])
|
||||
if speaker_embed is not None:
|
||||
log(' speaker embedding: %d' % speaker_embed.shape[-1])
|
||||
else:
|
||||
log(' speaker embedding: None')
|
||||
log(' prenet out: %d' % prenet_outputs.shape[-1])
|
||||
log(' encoder out: %d' % encoder_outputs.shape[-1])
|
||||
log(' attention out: %d' % attention_cell.output_size)
|
||||
log(' concat attn & out: %d' % concat_cell.output_size)
|
||||
log(' decoder cell out: %d' % decoder_cell.output_size)
|
||||
log(' decoder out (%d frames): %d' % (hp.reduction_factor, decoder_outputs.shape[-1]))
|
||||
log(' decoder out (1 frame): %d' % mel_outputs.shape[-1])
|
||||
log(' postnet out: %d' % post_outputs.shape[-1])
|
||||
log(' linear out: %d' % linear_outputs.shape[-1])
|
||||
|
||||
|
||||
def add_loss(self):
|
||||
'''Adds loss to the model. Sets "loss" field. initialize must have been called.'''
|
||||
with tf.variable_scope('loss') as scope:
|
||||
hp = self._hparams
|
||||
mel_loss = tf.abs(self.mel_targets - self.mel_outputs)
|
||||
|
||||
l1 = tf.abs(self.linear_targets - self.linear_outputs)
|
||||
expanded_loss_coeff = tf.expand_dims(
|
||||
tf.expand_dims(self.loss_coeff, [-1]), [-1])
|
||||
|
||||
if hp.prioritize_loss:
|
||||
# Prioritize loss for frequencies.
|
||||
upper_priority_freq = int(5000 / (hp.sample_rate * 0.5) * hp.num_freq)
|
||||
lower_priority_freq = int(165 / (hp.sample_rate * 0.5) * hp.num_freq)
|
||||
|
||||
l1_priority = l1[:, :, lower_priority_freq:upper_priority_freq]
|
||||
|
||||
self.loss = tf.reduce_mean(mel_loss * expanded_loss_coeff) + \
|
||||
0.5 * tf.reduce_mean(l1 * expanded_loss_coeff) + \
|
||||
0.5 * tf.reduce_mean(l1_priority * expanded_loss_coeff)
|
||||
self.linear_loss = tf.reduce_mean(
|
||||
0.5 * (tf.reduce_mean(l1) + tf.reduce_mean(l1_priority)))
|
||||
else:
|
||||
self.loss = tf.reduce_mean(mel_loss * expanded_loss_coeff) + \
|
||||
tf.reduce_mean(l1 * expanded_loss_coeff)
|
||||
self.linear_loss = tf.reduce_mean(l1)
|
||||
|
||||
self.mel_loss = tf.reduce_mean(mel_loss)
|
||||
self.loss_without_coeff = self.mel_loss + self.linear_loss
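As a quick sanity check of the priority band used above: assuming, purely for illustration, sample_rate=20000 and num_freq=1025 (the actual values live in hparams.py and are not confirmed here), the extra 0.5-weighted term covers roughly linear-spectrogram bins 16 to 512:

sample_rate, num_freq = 20000, 1025        # illustrative hyperparameter values
nyquist = sample_rate * 0.5                # 10000 Hz

upper_priority_freq = int(5000 / nyquist * num_freq)   # 512  (~5 kHz)
lower_priority_freq = int(165 / nyquist * num_freq)    # 16   (~165 Hz)

# Only l1[:, :, 16:512] receives the additional 0.5-weighted penalty,
# emphasizing the band where most speech energy lies.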
|
||||
|
||||
|
||||
def add_optimizer(self, global_step):
|
||||
'''Adds optimizer. Sets "gradients" and "optimize" fields. add_loss must have been called.
|
||||
|
||||
Args:
|
||||
global_step: int32 scalar Tensor representing current global step in training
|
||||
'''
|
||||
with tf.variable_scope('optimizer') as scope:
|
||||
hp = self._hparams
|
||||
|
||||
step = tf.cast(global_step + 1, dtype=tf.float32)
|
||||
|
||||
if hp.decay_learning_rate_mode == 0:
|
||||
if self.is_randomly_initialized:
|
||||
warmup_steps = 4000.0
|
||||
else:
|
||||
warmup_steps = 40000.0
|
||||
self.learning_rate = hp.initial_learning_rate * warmup_steps**0.5 * \
|
||||
tf.minimum(step * warmup_steps**-1.5, step**-0.5)
|
||||
elif hp.decay_learning_rate_mode == 1:
|
||||
self.learning_rate = hp.initial_learning_rate * \
|
||||
tf.train.exponential_decay(1., step, 3000, 0.95)
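Mode 0 above is the Noam-style warmup schedule: the rate rises roughly linearly for warmup_steps and then decays as 1/sqrt(step). A plain-Python sketch of the same formula (initial_learning_rate here is a stand-in value):

def noam_lr(global_step, initial_learning_rate=0.001, warmup_steps=4000.0):
    # lr = init * warmup_steps**0.5 * min(step * warmup_steps**-1.5, step**-0.5)
    step = float(global_step + 1)
    return initial_learning_rate * warmup_steps ** 0.5 * \
        min(step * warmup_steps ** -1.5, step ** -0.5)

# The schedule peaks near step == warmup_steps and decays afterwards:
# noam_lr(100) < noam_lr(4000) > noam_lr(40000)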
|
||||
|
||||
optimizer = tf.train.AdamOptimizer(self.learning_rate, hp.adam_beta1, hp.adam_beta2)
|
||||
gradients, variables = zip(*optimizer.compute_gradients(self.loss))
|
||||
self.gradients = gradients
|
||||
clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
|
||||
|
||||
# Add dependency on UPDATE_OPS; otherwise batchnorm won't work correctly. See:
|
||||
# https://github.com/tensorflow/tensorflow/issues/1122
|
||||
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
|
||||
self.optimize = optimizer.apply_gradients(zip(clipped_gradients, variables),
|
||||
global_step=global_step)
|
||||
|
||||
def get_dummy_feed_dict(self):
|
||||
feed_dict = {
|
||||
self.is_manual_attention: False,
|
||||
self.manual_alignments: np.zeros([1, 1, 1]),
|
||||
}
|
||||
return feed_dict
|
181
recognition/alignment.py
Normal file
|
@ -0,0 +1,181 @@
|
|||
import os
|
||||
import string
|
||||
import argparse
|
||||
import operator
|
||||
from functools import partial
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
from audio.get_duration import get_durations
|
||||
from text import remove_puncuations, text_to_sequence
|
||||
from utils import load_json, write_json, parallel_run, remove_postfix, backup_file
|
||||
|
||||
def plain_text(text):
|
||||
return "".join(remove_puncuations(text.strip()).split())
|
||||
|
||||
def add_punctuation(text):
|
||||
if text.endswith('다'):
|
||||
return text + "."
|
||||
else:
|
||||
return text
|
||||
|
||||
def similarity(text_a, text_b):
|
||||
text_a = plain_text(text_a)
|
||||
text_b = plain_text(text_b)
|
||||
|
||||
score = SequenceMatcher(None, text_a, text_b).ratio()
|
||||
return score
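similarity() is just a punctuation- and whitespace-insensitive difflib ratio. For example:

from difflib import SequenceMatcher

assert SequenceMatcher(None, "abcd", "abcd").ratio() == 1.0   # identical strings
# Because plain_text() strips punctuation and spaces first, these two also score 1.0:
#   similarity("안녕하세요, 여러분.", "안녕하세요 여러분")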
|
||||
|
||||
def first_word_combined_words(text):
|
||||
words = text.split()
|
||||
if len(words) > 1:
|
||||
first_words = [words[0], words[0]+words[1]]
|
||||
else:
|
||||
first_words = [words[0]]
|
||||
return first_words
|
||||
|
||||
def first_word_combined_texts(text):
|
||||
words = text.split()
|
||||
if len(words) > 1:
|
||||
if len(words) > 2:
|
||||
text2 = " ".join([words[0]+words[1]] + words[2:])
|
||||
else:
|
||||
text2 = words[0]+words[1]
|
||||
texts = [text, text2]
|
||||
else:
|
||||
texts = [text]
|
||||
return texts
|
||||
|
||||
def search_optimal(found_text, recognition_text):
|
||||
# 1. found_text is usually more accurate
|
||||
# 2. recognition_text can have more or less word
|
||||
|
||||
optimal = None
|
||||
|
||||
if plain_text(recognition_text) in plain_text(found_text):
|
||||
optimal = recognition_text
|
||||
else:
|
||||
found = False
|
||||
|
||||
for tmp_text in first_word_combined_texts(found_text):
|
||||
for recognition_first_word in first_word_combined_words(recognition_text):
|
||||
if recognition_first_word in tmp_text:
|
||||
start_idx = tmp_text.find(recognition_first_word)
|
||||
|
||||
if tmp_text != found_text:
|
||||
found_text = found_text[max(0, start_idx-1):].strip()
|
||||
else:
|
||||
found_text = found_text[start_idx:].strip()
|
||||
found = True
|
||||
break
|
||||
|
||||
if found:
|
||||
break
|
||||
|
||||
recognition_last_word = recognition_text.split()[-1]
|
||||
if recognition_last_word in found_text:
|
||||
end_idx = found_text.find(recognition_last_word)
|
||||
|
||||
punctuation = ""
|
||||
if len(found_text) > end_idx + len(recognition_last_word):
|
||||
punctuation = found_text[end_idx + len(recognition_last_word)]
|
||||
if punctuation not in string.punctuation:
|
||||
punctuation = ""
|
||||
|
||||
found_text = found_text[:end_idx] + recognition_last_word + punctuation
|
||||
found = True
|
||||
|
||||
if found:
|
||||
optimal = found_text
|
||||
|
||||
return optimal
|
||||
|
||||
|
||||
def align_text_for_jtbc(
|
||||
item, score_threshold, debug=False):
|
||||
|
||||
audio_path, recognition_text = item
|
||||
|
||||
audio_dir = os.path.dirname(audio_path)
|
||||
base_dir = os.path.dirname(audio_dir)
|
||||
|
||||
news_path = remove_postfix(audio_path.replace("audio", "assets"))
|
||||
news_path = os.path.splitext(news_path)[0] + ".txt"
|
||||
|
||||
strip_fn = lambda line: line.strip().replace('"', '').replace("'", "")
|
||||
candidates = [strip_fn(line) for line in open(news_path).readlines()]
|
||||
|
||||
scores = { candidate: similarity(candidate, recognition_text) \
|
||||
for candidate in candidates}
|
||||
sorted_scores = sorted(scores.items(), key=operator.itemgetter(1))[::-1]
|
||||
|
||||
first, second = sorted_scores[0], sorted_scores[1]
|
||||
|
||||
if first[1] > second[1] and first[1] >= score_threshold:
|
||||
found_text, score = first
|
||||
aligned_text = search_optimal(found_text, recognition_text)
|
||||
|
||||
if debug:
|
||||
print(" ", audio_path)
|
||||
print(" ", recognition_text)
|
||||
print("=> ", found_text)
|
||||
print("==>", aligned_text)
|
||||
print("="*30)
|
||||
|
||||
if aligned_text is not None:
|
||||
result = { audio_path: add_punctuation(aligned_text) }
|
||||
elif abs(len(text_to_sequence(found_text)) - len(text_to_sequence(recognition_text))) > 10:
|
||||
result = {}
|
||||
else:
|
||||
result = { audio_path: [add_punctuation(found_text), recognition_text] }
|
||||
else:
|
||||
result = {}
|
||||
|
||||
if len(result) == 0:
|
||||
result = { audio_path: [recognition_text] }
|
||||
|
||||
return result
|
||||
|
||||
def align_text_batch(config):
|
||||
if "jtbc" in config.recognition_path.lower():
|
||||
align_text = partial(align_text_for_jtbc,
|
||||
score_threshold=config.score_threshold)
|
||||
else:
|
||||
raise Exception(" [!] find_related_texts for `{}` is not defined". \
|
||||
format(config.recognition_path))
|
||||
|
||||
results = {}
|
||||
data = load_json(config.recognition_path)
|
||||
|
||||
items = parallel_run(
|
||||
align_text, data.items(),
|
||||
desc="align_text_batch", parallel=True)
|
||||
|
||||
for item in items:
|
||||
results.update(item)
|
||||
|
||||
found_count = sum([type(value) == str for value in results.values()])
|
||||
print(" [*] # found: {:.5f}% ({}/{})".format(
|
||||
100 * len(results) / len(data), len(results), len(data)))
|
||||
print(" [*] # exact match: {:.5f}% ({}/{})".format(
|
||||
100 * found_count / len(items), found_count, len(items)))
|
||||
|
||||
return results
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--recognition_path', required=True)
|
||||
parser.add_argument('--alignment_filename', default="alignment.json")
|
||||
parser.add_argument('--score_threshold', default=0.4, type=float)
|
||||
config, unparsed = parser.parse_known_args()
|
||||
|
||||
results = align_text_batch(config)
|
||||
|
||||
base_dir = os.path.dirname(config.recognition_path)
|
||||
alignment_path = \
|
||||
os.path.join(base_dir, config.alignment_filename)
|
||||
|
||||
if os.path.exists(alignment_path):
|
||||
backup_file(alignment_path)
|
||||
|
||||
write_json(alignment_path, results)
|
||||
duration = get_durations(results.keys(), print_detail=False)
|
126
recognition/google.py
Normal file
|
@ -0,0 +1,126 @@
|
|||
import io
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
from functools import partial
|
||||
|
||||
from utils import parallel_run, remove_file, backup_file, write_json
|
||||
from audio import load_audio, save_audio, resample_audio, get_duration
|
||||
|
||||
|
||||
def text_recognition(path, config):
|
||||
root, ext = os.path.splitext(path)
|
||||
txt_path = root + ".txt"
|
||||
|
||||
if os.path.exists(txt_path):
|
||||
with open(txt_path) as f:
|
||||
out = json.loads(f.read())
|
||||
return out
|
||||
|
||||
from google.cloud import speech
|
||||
from google.cloud.speech import enums
|
||||
from google.cloud.speech import types
|
||||
|
||||
out = {}
|
||||
error_count = 0
|
||||
|
||||
tmp_path = os.path.splitext(path)[0] + ".tmp.wav"
|
||||
|
||||
while True:
|
||||
try:
|
||||
client = speech.SpeechClient()
|
||||
|
||||
content = load_audio(
|
||||
path, pre_silence_length=config.pre_silence_length,
|
||||
post_silence_length=config.post_silence_length)
|
||||
|
||||
max_duration = config.max_duration - \
|
||||
config.pre_silence_length - config.post_silence_length
|
||||
audio_duration = get_duration(content)
|
||||
|
||||
if audio_duration >= max_duration:
|
||||
print(" [!] Skip {} because of duration: {} > {}". \
|
||||
format(path, audio_duration, max_duration))
|
||||
return {}
|
||||
|
||||
content = resample_audio(content, config.sample_rate)
|
||||
save_audio(content, tmp_path, config.sample_rate)
|
||||
|
||||
with io.open(tmp_path, 'rb') as f:
|
||||
audio = types.RecognitionAudio(content=f.read())
|
||||
|
||||
config = types.RecognitionConfig(
|
||||
encoding=enums.RecognitionConfig.AudioEncoding.LINEAR16,
|
||||
sample_rate_hertz=config.sample_rate,
|
||||
language_code='ko-KR')
|
||||
|
||||
response = client.recognize(config, audio)
|
||||
if len(response.results) > 0:
|
||||
alternatives = response.results[0].alternatives
|
||||
|
||||
results = [alternative.transcript for alternative in alternatives]
|
||||
assert len(results) == 1, "More than 1 results: {}".format(results)
|
||||
|
||||
out = { path: "" if len(results) == 0 else results[0] }
|
||||
print(path, results[0])
|
||||
break
|
||||
break
|
||||
except Exception as err:
|
||||
raise Exception("OS error: {0}".format(err))
|
||||
|
||||
error_count += 1
|
||||
print("Skip warning for {} for {} times". \
|
||||
format(path, error_count))
|
||||
|
||||
if error_count > 5:
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
||||
remove_file(tmp_path)
|
||||
with open(txt_path, 'w') as f:
|
||||
json.dump(out, f, indent=2, ensure_ascii=False)
|
||||
|
||||
return out
|
||||
|
||||
def text_recognition_batch(paths, config):
|
||||
paths.sort()
|
||||
|
||||
results = {}
|
||||
items = parallel_run(
|
||||
partial(text_recognition, config=config), paths,
|
||||
desc="text_recognition_batch", parallel=True)
|
||||
for item in items:
|
||||
results.update(item)
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--audio_pattern', required=True)
|
||||
parser.add_argument('--recognition_filename', default="recognition.json")
|
||||
parser.add_argument('--sample_rate', default=16000, type=int)
|
||||
parser.add_argument('--pre_silence_length', default=1, type=int)
|
||||
parser.add_argument('--post_silence_length', default=1, type=int)
|
||||
parser.add_argument('--max_duration', default=60, type=int)
|
||||
config, unparsed = parser.parse_known_args()
|
||||
|
||||
audio_dir = os.path.dirname(config.audio_pattern)
|
||||
|
||||
for tmp_path in glob(os.path.join(audio_dir, "*.tmp.*")):
|
||||
remove_file(tmp_path)
|
||||
|
||||
paths = glob(config.audio_pattern)
|
||||
paths.sort()
|
||||
results = text_recognition_batch(paths, config)
|
||||
|
||||
base_dir = os.path.dirname(audio_dir)
|
||||
recognition_path = \
|
||||
os.path.join(base_dir, config.recognition_filename)
|
||||
|
||||
if os.path.exists(recognition_path):
|
||||
backup_file(recognition_path)
|
||||
|
||||
write_json(recognition_path, results)
|
75
requirements.txt
Normal file
|
@ -0,0 +1,75 @@
|
|||
appnope==0.1.0
|
||||
audioread==2.1.5
|
||||
bleach==1.5.0
|
||||
certifi==2017.7.27.1
|
||||
chardet==3.0.4
|
||||
click==6.7
|
||||
cycler==0.10.0
|
||||
Cython==0.26.1
|
||||
decorator==4.0.11
|
||||
entrypoints==0.2.3
|
||||
Flask==0.12.2
|
||||
Flask-Cors==3.0.3
|
||||
gTTS==1.2.2
|
||||
gTTS-token==1.1.1
|
||||
html5lib==0.9999999
|
||||
idna==2.6
|
||||
imageio==2.1.2
|
||||
ipdb==0.10.3
|
||||
ipykernel==4.6.1
|
||||
ipython==6.1.0
|
||||
ipython-genutils==0.2.0
|
||||
ipywidgets==7.0.1
|
||||
itsdangerous==0.24
|
||||
jamo==0.4.0
|
||||
jedi==0.10.2
|
||||
Jinja2==2.9.6
|
||||
joblib==0.11
|
||||
jsonschema==2.6.0
|
||||
jupyter-client==5.1.0
|
||||
jupyter-core==4.3.0
|
||||
librosa==0.5.1
|
||||
llvmlite==0.20.0
|
||||
Markdown==2.6.9
|
||||
MarkupSafe==1.0
|
||||
matplotlib==2.0.2
|
||||
mistune==0.7.4
|
||||
moviepy==0.2.3.2
|
||||
nbconvert==5.3.1
|
||||
nbformat==4.4.0
|
||||
nltk==3.2.4
|
||||
notebook==5.1.0
|
||||
numba==0.35.0
|
||||
numpy==1.13.3
|
||||
olefile==0.44
|
||||
pandocfilters==1.4.2
|
||||
pexpect==4.2.1
|
||||
pickleshare==0.7.4
|
||||
Pillow==4.3.0
|
||||
prompt-toolkit==1.0.15
|
||||
protobuf==3.4.0
|
||||
ptyprocess==0.5.2
|
||||
pydub==0.20.0
|
||||
Pygments==2.2.0
|
||||
pyparsing==2.2.0
|
||||
python-dateutil==2.6.1
|
||||
pytz==2017.2
|
||||
pyzmq==16.0.2
|
||||
requests==2.18.4
|
||||
resampy==0.2.0
|
||||
scikit-learn==0.19.0
|
||||
scipy==0.19.1
|
||||
simplegeneric==0.8.1
|
||||
six==1.11.0
|
||||
tensorflow==1.3.0
|
||||
tensorflow-tensorboard==0.1.6
|
||||
terminado==0.6
|
||||
testpath==0.3.1
|
||||
tornado==4.5.2
|
||||
tqdm==4.11.2
|
||||
traitlets==4.3.2
|
||||
urllib3==1.22
|
||||
wcwidth==0.1.7
|
||||
Werkzeug==0.12.2
|
||||
widgetsnbextension==3.0.3
|
||||
youtube-dl==2017.10.7
|
8
run.sh
Normal file
|
@ -0,0 +1,8 @@
|
|||
#!/bin/sh
|
||||
|
||||
CUDA_VISIBLE_DEVICES= python app.py --load_path logs/deepvoice2-256-256-krbook-bah-mon-22000-no-priority --dataname=krbook --num_speakers=1
|
||||
CUDA_VISIBLE_DEVICES= python app.py --load_path logs/jtbc_2017-09-25_11-49-23 --dataname=krbook --num_speakers=1 --port=5002
|
||||
CUDA_VISIBLE_DEVICES= python app.py --load_path logs/krbook_2017-09-27_17-02-44 --dataname=krbook --num_speakers=1 --port=5001
|
||||
CUDA_VISIBLE_DEVICES= python app.py --load_path logs/krfemale_2017-10-10_20-37-38 --dataname=krbook --num_speakers=1 --port=5003
|
||||
CUDA_VISIBLE_DEVICES= python app.py --load_path logs/krmale_2017-10-10_17-49-49 --dataname=krbook --num_speakers=1 --port=5005
|
||||
CUDA_VISIBLE_DEVICES= python app.py --load_path logs/park+moon+krbook_2017-10-09_20-43-53 --dataname=krbook --num_speakers=3 --port=5004
|
16
scripts/prepare_jtbc.sh
Executable file
|
@ -0,0 +1,16 @@
|
|||
#!/bin/sh
|
||||
|
||||
# 1. Download and extract audio and texts
|
||||
python -m datasets.jtbc.download
|
||||
|
||||
# 2. Split audios on silence
|
||||
python -m audio.silence --audio_pattern "./datasets/jtbc/audio/*.wav" --method=pydub
|
||||
|
||||
# 3. Run Google Speech Recognition
|
||||
python -m recognition.google --audio_pattern "./datasets/jtbc/audio/*.*.wav"
|
||||
|
||||
# 4. Run heuristic text-audio pair search (any improvement on this is welcome)
|
||||
python -m recognition.alignment --recognition_path "./datasets/jtbc/recognition.json" --score_threshold=0.5
|
||||
|
||||
# 5. Remove intro music
|
||||
rm datasets/jtbc/data/*.0000.npz
|
13
scripts/prepare_moon.sh
Executable file
|
@ -0,0 +1,13 @@
|
|||
#!/bin/sh
|
||||
|
||||
# 1. Download and extract audio and texts
|
||||
python -m datasets.moon.download
|
||||
|
||||
# 2. Split audios on silence
|
||||
python -m audio.silence --audio_pattern "./datasets/moon/audio/*.wav" --method=pydub
|
||||
|
||||
# 3. Run Google Speech Recognition
|
||||
python -m recognition.google --audio_pattern "./datasets/moon/audio/*.*.wav"
|
||||
|
||||
# 4. Run heuristic text-audio pair search (any improvement on this is welcome)
|
||||
python -m recognition.alignment --recognition_path "./datasets/moon/recognition.json" --score_threshold=0.5
|
13
scripts/prepare_park.sh
Executable file
|
@ -0,0 +1,13 @@
|
|||
#!/bin/sh
|
||||
|
||||
# 1. Download and extract audio and texts
|
||||
python -m datasets.park.download
|
||||
|
||||
# 2. Split audios on silence
|
||||
python -m audio.silence --audio_pattern "./datasets/park/audio/*.wav" --method=pydub
|
||||
|
||||
# 3. Run Google Speech Recognition
|
||||
python -m recognition.google --audio_pattern "./datasets/park/audio/*.*.wav"
|
||||
|
||||
# 4. Run heuristic text-audio pair search (any improvement on this is welcome)
|
||||
python -m recognition.alignment --recognition_path "./datasets/park/recognition.json" --score_threshold=0.5
|
389
synthesizer.py
Normal file
|
@ -0,0 +1,389 @@
|
|||
import io
|
||||
import os
|
||||
import re
|
||||
import librosa
|
||||
import argparse
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
from tqdm import tqdm
|
||||
import tensorflow as tf
|
||||
from functools import partial
|
||||
|
||||
from hparams import hparams
|
||||
from models import create_model, get_most_recent_checkpoint
|
||||
from audio import save_audio, inv_spectrogram, inv_preemphasis, \
|
||||
inv_spectrogram_tensorflow
|
||||
from utils import plot, PARAMS_NAME, load_json, load_hparams, \
|
||||
add_prefix, add_postfix, get_time, parallel_run, makedirs
|
||||
|
||||
from text.korean import tokenize
|
||||
from text import text_to_sequence, sequence_to_text
|
||||
|
||||
|
||||
class Synthesizer(object):
|
||||
def close(self):
|
||||
tf.reset_default_graph()
|
||||
self.sess.close()
|
||||
|
||||
def load(self, checkpoint_path, num_speakers=2, checkpoint_step=None, model_name='tacotron'):
|
||||
self.num_speakers = num_speakers
|
||||
|
||||
if os.path.isdir(checkpoint_path):
|
||||
load_path = checkpoint_path
|
||||
checkpoint_path = get_most_recent_checkpoint(checkpoint_path, checkpoint_step)
|
||||
else:
|
||||
load_path = os.path.dirname(checkpoint_path)
|
||||
|
||||
print('Constructing model: %s' % model_name)
|
||||
|
||||
inputs = tf.placeholder(tf.int32, [None, None], 'inputs')
|
||||
input_lengths = tf.placeholder(tf.int32, [None], 'input_lengths')
|
||||
|
||||
batch_size = tf.shape(inputs)[0]
|
||||
speaker_id = tf.placeholder_with_default(
|
||||
tf.zeros([batch_size], dtype=tf.int32), [None], 'speaker_id')
|
||||
|
||||
load_hparams(hparams, load_path)
|
||||
with tf.variable_scope('model') as scope:
|
||||
self.model = create_model(hparams)
|
||||
|
||||
self.model.initialize(
|
||||
inputs, input_lengths,
|
||||
self.num_speakers, speaker_id)
|
||||
self.wav_output = \
|
||||
inv_spectrogram_tensorflow(self.model.linear_outputs)
|
||||
|
||||
print('Loading checkpoint: %s' % checkpoint_path)
|
||||
|
||||
sess_config = tf.ConfigProto(
|
||||
allow_soft_placement=True,
|
||||
intra_op_parallelism_threads=1,
|
||||
inter_op_parallelism_threads=2)
|
||||
sess_config.gpu_options.allow_growth = True
|
||||
|
||||
self.sess = tf.Session(config=sess_config)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
saver = tf.train.Saver()
|
||||
saver.restore(self.sess, checkpoint_path)
|
||||
|
||||
def synthesize(self,
|
||||
texts=None, tokens=None,
|
||||
base_path=None, paths=None, speaker_ids=None,
|
||||
start_of_sentence=None, end_of_sentence=True,
|
||||
pre_word_num=0, post_word_num=0,
|
||||
pre_surplus_idx=0, post_surplus_idx=1,
|
||||
use_short_concat=False,
|
||||
manual_attention_mode=0,
|
||||
base_alignment_path=None,
|
||||
librosa_trim=False,
|
||||
attention_trim=True):
|
||||
|
||||
# Possible inputs:
|
||||
# 1) text=text
|
||||
# 2) text=texts
|
||||
# 3) tokens=tokens, texts=texts # use texts as guide
|
||||
|
||||
if type(texts) == str:
|
||||
texts = [texts]
|
||||
|
||||
if texts is not None and tokens is None:
|
||||
sequences = [text_to_sequence(text) for text in texts]
|
||||
elif tokens is not None:
|
||||
sequences = tokens
|
||||
|
||||
if paths is None:
|
||||
paths = [None] * len(sequences)
|
||||
if texts is None:
|
||||
texts = [None] * len(sequences)
|
||||
|
||||
time_str = get_time()
|
||||
def plot_and_save_parallel(
|
||||
wavs, alignments, use_manual_attention):
|
||||
|
||||
items = list(enumerate(zip(
|
||||
wavs, alignments, paths, texts, sequences)))
|
||||
|
||||
fn = partial(
|
||||
plot_graph_and_save_audio,
|
||||
base_path=base_path,
|
||||
start_of_sentence=start_of_sentence, end_of_sentence=end_of_sentence,
|
||||
pre_word_num=pre_word_num, post_word_num=post_word_num,
|
||||
pre_surplus_idx=pre_surplus_idx, post_surplus_idx=post_surplus_idx,
|
||||
use_short_concat=use_short_concat,
|
||||
use_manual_attention=use_manual_attention,
|
||||
librosa_trim=librosa_trim,
|
||||
attention_trim=attention_trim,
|
||||
time_str=time_str)
|
||||
return parallel_run(fn, items,
|
||||
desc="plot_graph_and_save_audio", parallel=False)
|
||||
|
||||
input_lengths = np.argmax(np.array(sequences) == 1, 1)
|
||||
|
||||
fetches = [
|
||||
#self.wav_output,
|
||||
self.model.linear_outputs,
|
||||
self.model.alignments,
|
||||
]
|
||||
|
||||
feed_dict = {
|
||||
self.model.inputs: sequences,
|
||||
self.model.input_lengths: input_lengths,
|
||||
}
|
||||
if base_alignment_path is None:
|
||||
feed_dict.update({
|
||||
self.model.manual_alignments: np.zeros([1, 1, 1]),
|
||||
self.model.is_manual_attention: False,
|
||||
})
|
||||
else:
|
||||
manual_alignments = []
|
||||
alignment_path = os.path.join(
|
||||
base_alignment_path,
|
||||
os.path.basename(base_path))
|
||||
|
||||
for idx in range(len(sequences)):
|
||||
numpy_path = "{}.{}.npy".format(alignment_path, idx)
|
||||
manual_alignments.append(np.load(numpy_path))
|
||||
|
||||
alignments_T = np.transpose(manual_alignments, [0, 2, 1])
|
||||
feed_dict.update({
|
||||
self.model.manual_alignments: alignments_T,
|
||||
self.model.is_manual_attention: True,
|
||||
})
|
||||
|
||||
if speaker_ids is not None:
|
||||
if type(speaker_ids) == dict:
|
||||
speaker_embed_table = self.sess.run(
|
||||
self.model.speaker_embed_table)
|
||||
|
||||
speaker_embed = [speaker_ids[speaker_id] * \
|
||||
speaker_embed_table[speaker_id] for speaker_id in speaker_ids]
|
||||
feed_dict.update({
|
||||
self.model.speaker_embed_table: np.tile()
|
||||
})
|
||||
else:
|
||||
feed_dict[self.model.speaker_id] = speaker_ids
|
||||
|
||||
wavs, alignments = \
|
||||
self.sess.run(fetches, feed_dict=feed_dict)
|
||||
results = plot_and_save_parallel(
|
||||
wavs, alignments, True)
|
||||
|
||||
if manual_attention_mode > 0:
|
||||
# argmax one hot
|
||||
if manual_attention_mode == 1:
|
||||
alignments_T = np.transpose(alignments, [0, 2, 1]) # [N, E, D]
|
||||
new_alignments = np.zeros_like(alignments_T)
|
||||
|
||||
for idx in range(len(alignments)):
|
||||
argmax = alignments[idx].argmax(1)
|
||||
new_alignments[idx][(argmax, range(len(argmax)))] = 1
|
||||
# sharpening
|
||||
elif manual_attention_mode == 2:
|
||||
new_alignments = np.transpose(alignments, [0, 2, 1]) # [N, E, D]
|
||||
|
||||
for idx in range(len(alignments)):
|
||||
var = np.var(new_alignments[idx], 1)
|
||||
mean_var = var[:input_lengths[idx]].mean()
|
||||
|
||||
new_alignments[idx] = np.power(new_alignments[idx], 2)
|
||||
# pruning
|
||||
elif manual_attention_mode == 3:
|
||||
new_alignments = np.transpose(alignments, [0, 2, 1]) # [N, E, D]
|
||||
|
||||
for idx in range(len(alignments)):
|
||||
argmax = alignments[idx].argmax(1)
|
||||
new_alignments[idx][(argmax, range(len(argmax)))] = 1
|
||||
|
||||
feed_dict.update({
|
||||
self.model.manual_alignments: new_alignments,
|
||||
self.model.is_manual_attention: True,
|
||||
})
|
||||
|
||||
new_wavs, new_alignments = \
|
||||
self.sess.run(fetches, feed_dict=feed_dict)
|
||||
results = plot_and_save_parallel(
|
||||
new_wavs, new_alignments, True)
|
||||
|
||||
return results
|
||||
|
||||
def plot_graph_and_save_audio(args,
|
||||
base_path=None,
|
||||
start_of_sentence=None, end_of_sentence=None,
|
||||
pre_word_num=0, post_word_num=0,
|
||||
pre_surplus_idx=0, post_surplus_idx=1,
|
||||
use_short_concat=False,
|
||||
use_manual_attention=False, save_alignment=False,
|
||||
librosa_trim=False, attention_trim=False,
|
||||
time_str=None):
|
||||
|
||||
idx, (wav, alignment, path, text, sequence) = args
|
||||
|
||||
if base_path:
|
||||
plot_path = "{}/{}.png".format(base_path, get_time())
|
||||
elif path:
|
||||
plot_path = path.rsplit('.', 1)[0] + ".png"
|
||||
else:
|
||||
plot_path = None
|
||||
|
||||
#plot_path = add_prefix(plot_path, time_str)
|
||||
if use_manual_attention:
|
||||
plot_path = add_postfix(plot_path, "manual")
|
||||
|
||||
if plot_path:
|
||||
plot.plot_alignment(alignment, plot_path, text=text)
|
||||
|
||||
if use_short_concat:
|
||||
wav = short_concat(
|
||||
wav, alignment, text,
|
||||
start_of_sentence, end_of_sentence,
|
||||
pre_word_num, post_word_num,
|
||||
pre_surplus_idx, post_surplus_idx)
|
||||
|
||||
if attention_trim and end_of_sentence:
|
||||
end_idx_counter = 0
|
||||
attention_argmax = alignment.argmax(0)
|
||||
end_idx = min(len(sequence) - 1, max(attention_argmax))
|
||||
max_counter = min((attention_argmax == end_idx).sum(), 5)
|
||||
|
||||
for jdx, attend_idx in enumerate(attention_argmax):
|
||||
if len(attention_argmax) > jdx + 1:
|
||||
if attend_idx == end_idx:
|
||||
end_idx_counter += 1
|
||||
|
||||
if attend_idx == end_idx and attention_argmax[jdx + 1] > end_idx:
|
||||
break
|
||||
|
||||
if end_idx_counter >= max_counter:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
spec_end_idx = hparams.reduction_factor * jdx + 3
|
||||
wav = wav[:spec_end_idx]
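The loop above walks the per-frame attention argmax and cuts the spectrogram a few frames after attention has settled on the last attended input symbol. A toy NumPy illustration of the quantities involved (the alignment matrix is made up):

import numpy as np

# Rows = input symbols (T_in), columns = decoder frames (T_out).
alignment = np.array([
    [0.9, 0.1, 0.0, 0.0, 0.0, 0.0],
    [0.1, 0.8, 0.7, 0.1, 0.0, 0.0],
    [0.0, 0.1, 0.3, 0.9, 0.9, 0.9],   # last symbol; attention parks here at the end
])

attention_argmax = alignment.argmax(0)   # -> array([0, 1, 1, 2, 2, 2])
end_idx = attention_argmax.max()         # index of the last attended symbol (2)

# Decoding is kept up to a few frames past the point where the argmax reaches
# end_idx, i.e. roughly reduction_factor * jdx + 3 spectrogram frames.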
|
||||
|
||||
audio_out = inv_spectrogram(wav.T)
|
||||
|
||||
if librosa_trim and end_of_sentence:
|
||||
yt, index = librosa.effects.trim(audio_out,
|
||||
frame_length=5120, hop_length=256, top_db=50)
|
||||
audio_out = audio_out[:index[-1]]
|
||||
|
||||
if save_alignment:
|
||||
alignment_path = "{}/{}.npy".format(base_path, idx)
|
||||
np.save(alignment_path, alignment, allow_pickle=False)
|
||||
|
||||
if path or base_path:
|
||||
if path:
|
||||
current_path = add_postfix(path, idx)
|
||||
elif base_path:
|
||||
current_path = plot_path.replace(".png", ".wav")
|
||||
|
||||
save_audio(audio_out, current_path)
|
||||
return True
|
||||
else:
|
||||
io_out = io.BytesIO()
|
||||
save_audio(audio_out, io_out)
|
||||
result = io_out.getvalue()
|
||||
return result
|
||||
|
||||
def get_most_recent_checkpoint(checkpoint_dir, checkpoint_step=None):
|
||||
if checkpoint_step is None:
|
||||
checkpoint_paths = [path for path in glob("{}/*.ckpt-*.data-*".format(checkpoint_dir))]
|
||||
idxes = [int(os.path.basename(path).split('-')[1].split('.')[0]) for path in checkpoint_paths]
|
||||
|
||||
max_idx = max(idxes)
|
||||
else:
|
||||
max_idx = checkpoint_step
|
||||
latest_checkpoint = os.path.join(checkpoint_dir, "model.ckpt-{}".format(max_idx))
|
||||
print(" [*] Found lastest checkpoint: {}".format(lastest_checkpoint))
|
||||
return latest_checkpoint
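The step index is recovered from the checkpoint shard filename, for example (the path below is hypothetical):

import os

path = "logs/krbook/model.ckpt-22000.data-00000-of-00001"   # hypothetical shard
step = int(os.path.basename(path).split('-')[1].split('.')[0])
assert step == 22000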
|
||||
|
||||
def short_concat(
|
||||
wav, alignment, text,
|
||||
start_of_sentence, end_of_sentence,
|
||||
pre_word_num, post_word_num,
|
||||
pre_surplus_idx, post_surplus_idx):
|
||||
|
||||
# np.array(list(decomposed_text))[attention_argmax]
|
||||
attention_argmax = alignment.argmax(0)
|
||||
|
||||
if not start_of_sentence and pre_word_num > 0:
|
||||
surplus_decomposed_text = decompose_ko_text("".join(text.split()[0]))
|
||||
start_idx = len(surplus_decomposed_text) + 1
|
||||
|
||||
for idx, attend_idx in enumerate(attention_argmax):
|
||||
if attend_idx == start_idx and attention_argmax[idx - 1] < start_idx:
|
||||
break
|
||||
|
||||
wav_start_idx = hparams.reduction_factor * idx - 1 - pre_surplus_idx
|
||||
else:
|
||||
wav_start_idx = 0
|
||||
|
||||
if not end_of_sentence and post_word_num > 0:
|
||||
surplus_decomposed_text = decompose_ko_text("".join(text.split()[-1]))
|
||||
end_idx = len(decomposed_text.replace(surplus_decomposed_text, '')) - 1
|
||||
|
||||
for idx, attend_idx in enumerate(attention_argmax):
|
||||
if attend_idx == end_idx and attention_argmax[idx + 1] > end_idx:
|
||||
break
|
||||
|
||||
wav_end_idx = hparams.reduction_factor * idx + 1 + post_surplus_idx
|
||||
else:
|
||||
if True: # attention based split
|
||||
if end_of_sentence:
|
||||
end_idx = min(len(decomposed_text) - 1, max(attention_argmax))
|
||||
else:
|
||||
surplus_decomposed_text = decompose_ko_text("".join(text.split()[-1]))
|
||||
end_idx = len(decomposed_text.replace(surplus_decomposed_text, '')) - 1
|
||||
|
||||
while True:
|
||||
if end_idx in attention_argmax:
|
||||
break
|
||||
end_idx -= 1
|
||||
|
||||
end_idx_counter = 0
|
||||
for idx, attend_idx in enumerate(attention_argmax):
|
||||
if len(attention_argmax) > idx + 1:
|
||||
if attend_idx == end_idx:
|
||||
end_idx_counter += 1
|
||||
|
||||
if attend_idx == end_idx and attention_argmax[idx + 1] > end_idx:
|
||||
break
|
||||
|
||||
if end_idx_counter > 5:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
wav_end_idx = hparams.reduction_factor * idx + 1 + post_surplus_idx
|
||||
else:
|
||||
wav_end_idx = None
|
||||
|
||||
wav = wav[wav_start_idx:wav_end_idx]
|
||||
|
||||
if end_of_sentence:
|
||||
wav = np.lib.pad(wav, ((0, 20), (0, 0)), 'constant', constant_values=0)
|
||||
else:
|
||||
wav = np.lib.pad(wav, ((0, 10), (0, 0)), 'constant', constant_values=0)

return wav
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--load_path', required=True)
|
||||
parser.add_argument('--sample_path', default="samples")
|
||||
parser.add_argument('--text', required=True)
|
||||
parser.add_argument('--num_speakers', default=1, type=int)
|
||||
parser.add_argument('--speaker_id', default=0, type=int)
|
||||
parser.add_argument('--checkpoint_step', default=None, type=int)
|
||||
config = parser.parse_args()
|
||||
|
||||
makedirs(config.sample_path)
|
||||
|
||||
synthesizer = Synthesizer()
|
||||
synthesizer.load(config.load_path, config.num_speakers, config.checkpoint_step)
|
||||
|
||||
audio = synthesizer.synthesize(
|
||||
texts=[config.text],
|
||||
base_path=config.sample_path,
|
||||
speaker_ids=[config.speaker_id],
|
||||
attention_trim=False)[0]
|
101
text/__init__.py
Normal file
|
@ -0,0 +1,101 @@
|
|||
import re
|
||||
import string
|
||||
import numpy as np
|
||||
|
||||
from text import cleaners
|
||||
from hparams import hparams
|
||||
from text.symbols import symbols, PAD, EOS
|
||||
from text.korean import jamo_to_korean
|
||||
|
||||
|
||||
# Mappings from symbol to numeric ID and vice versa:
|
||||
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
||||
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
|
||||
|
||||
# Regular expression matching text enclosed in curly braces:
|
||||
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
|
||||
|
||||
puncuation_table = str.maketrans({key: None for key in string.punctuation})
|
||||
|
||||
def remove_puncuations(text):
|
||||
return text.translate(puncuation_table)
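# Example (uses string.punctuation, so only ASCII punctuation is stripped):
#   remove_puncuations("Hello, world!")  # -> "Hello world"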
|
||||
|
||||
|
||||
def text_to_sequence(text, as_token=False):
|
||||
cleaner_names = [x.strip() for x in hparams.cleaners.split(',')]
|
||||
return _text_to_sequence(text, cleaner_names, as_token)
|
||||
|
||||
def _text_to_sequence(text, cleaner_names, as_token):
|
||||
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
||||
|
||||
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
|
||||
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
|
||||
|
||||
Args:
|
||||
text: string to convert to a sequence
|
||||
cleaner_names: names of the cleaner functions to run the text through
|
||||
|
||||
Returns:
|
||||
List of integers corresponding to the symbols in the text
|
||||
'''
|
||||
sequence = []
|
||||
|
||||
# Check for curly braces and treat their contents as ARPAbet:
|
||||
while len(text):
|
||||
m = _curly_re.match(text)
|
||||
if not m:
|
||||
sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
|
||||
break
|
||||
sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
|
||||
sequence += _arpabet_to_sequence(m.group(2))
|
||||
text = m.group(3)
|
||||
|
||||
# Append EOS token
|
||||
sequence.append(_symbol_to_id[EOS])
|
||||
|
||||
if as_token:
|
||||
return sequence_to_text(sequence, combine_jamo=True)
|
||||
else:
|
||||
return np.array(sequence, dtype=np.int32)
|
||||
|
||||
|
||||
def sequence_to_text(sequence, skip_eos_and_pad=False, combine_jamo=False):
|
||||
'''Converts a sequence of IDs back to a string'''
|
||||
result = ''
|
||||
for symbol_id in sequence:
|
||||
if symbol_id in _id_to_symbol:
|
||||
s = _id_to_symbol[symbol_id]
|
||||
# Enclose ARPAbet back in curly braces:
|
||||
if len(s) > 1 and s[0] == '@':
|
||||
s = '{%s}' % s[1:]
|
||||
|
||||
if not skip_eos_and_pad or s not in [EOS, PAD]:
|
||||
result += s
|
||||
|
||||
result = result.replace('}{', ' ')
|
||||
|
||||
if combine_jamo:
|
||||
return jamo_to_korean(result)
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
def _clean_text(text, cleaner_names):
|
||||
for name in cleaner_names:
|
||||
cleaner = getattr(cleaners, name)
|
||||
if not cleaner:
|
||||
raise Exception('Unknown cleaner: %s' % name)
|
||||
text = cleaner(text)
|
||||
return text
|
||||
|
||||
|
||||
def _symbols_to_sequence(symbols):
|
||||
return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
|
||||
|
||||
|
||||
def _arpabet_to_sequence(text):
|
||||
return _symbols_to_sequence(['@' + s for s in text.split()])
|
||||
|
||||
|
||||
def _should_keep_symbol(s):
|
||||
return s in _symbol_to_id and s != PAD and s != EOS
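# Illustrative round trip (a sketch; assumes hparams.cleaners is set to
# "korean_cleaners" so Hangul is decomposed into jamo symbols):
#   seq = text_to_sequence("안녕하세요")   # np.int32 IDs, EOS-terminated
#   sequence_to_text(seq, skip_eos_and_pad=True, combine_jamo=True)
#   # -> "안녕하세요"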
|
80
text/cleaners.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
'''
|
||||
Cleaners are transformations that run over the input text at both training and eval time.
|
||||
|
||||
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
|
||||
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
|
||||
1. "english_cleaners" for English text
|
||||
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
||||
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
||||
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
||||
the symbols in symbols.py to match your data).
|
||||
'''
|
||||
|
||||
import re
|
||||
from .korean import tokenize as ko_tokenize
from .en_numbers import normalize_numbers  # used by expand_numbers() below
|
||||
|
||||
|
||||
# Regular expression matching whitespace:
|
||||
_whitespace_re = re.compile(r'\s+')
|
||||
|
||||
|
||||
def korean_cleaners(text):
|
||||
'''Pipeline for Korean text, including number and abbreviation expansion.'''
|
||||
text = ko_tokenize(text)
|
||||
return text
|
||||
|
||||
|
||||
# List of (regular expression, replacement) pairs for abbreviations:
|
||||
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
|
||||
('mrs', 'misess'),
|
||||
('mr', 'mister'),
|
||||
('dr', 'doctor'),
|
||||
('st', 'saint'),
|
||||
('co', 'company'),
|
||||
('jr', 'junior'),
|
||||
('maj', 'major'),
|
||||
('gen', 'general'),
|
||||
('drs', 'doctors'),
|
||||
('rev', 'reverend'),
|
||||
('lt', 'lieutenant'),
|
||||
('hon', 'honorable'),
|
||||
('sgt', 'sergeant'),
|
||||
('capt', 'captain'),
|
||||
('esq', 'esquire'),
|
||||
('ltd', 'limited'),
|
||||
('col', 'colonel'),
|
||||
('ft', 'fort'),
|
||||
]]
|
||||
|
||||
|
||||
def expand_abbreviations(text):
|
||||
for regex, replacement in _abbreviations:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
|
||||
|
||||
def expand_numbers(text):
|
||||
return normalize_numbers(text)
|
||||
|
||||
|
||||
def lowercase(text):
|
||||
return text.lower()
|
||||
|
||||
|
||||
def collapse_whitespace(text):
|
||||
return re.sub(_whitespace_re, ' ', text)
|
||||
|
||||
|
||||
def basic_cleaners(text):
|
||||
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def transliteration_cleaners(text):
|
||||
'''Pipeline for non-English text that transliterates to ASCII.'''
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
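# "english_cleaners" is mentioned in the module docstring but missing from this
# snapshot, as is convert_to_ascii() used by transliteration_cleaners() above.
# The definitions below are an assumed reconstruction: convert_to_ascii() is
# taken to wrap the Unidecode library named in the docstring, and
# english_cleaners() simply composes the helpers defined above.
from unidecode import unidecode  # assumed dependency

def convert_to_ascii(text):
    # Transliterate non-ASCII characters to their closest ASCII equivalents.
    return unidecode(text)


def english_cleaners(text):
    '''Sketch of a pipeline for English text: transliteration, lowercasing,
    number and abbreviation expansion, whitespace collapsing.'''
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_numbers(text)
    text = expand_abbreviations(text)
    text = collapse_whitespace(text)
    return text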
|
69
text/en_numbers.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
import inflect
|
||||
import re
|
||||
|
||||
|
||||
_inflect = inflect.engine()
|
||||
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
|
||||
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
|
||||
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
|
||||
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
|
||||
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
|
||||
_number_re = re.compile(r'[0-9]+')
|
||||
|
||||
|
||||
def _remove_commas(m):
|
||||
return m.group(1).replace(',', '')
|
||||
|
||||
|
||||
def _expand_decimal_point(m):
|
||||
return m.group(1).replace('.', ' point ')
|
||||
|
||||
|
||||
def _expand_dollars(m):
|
||||
match = m.group(1)
|
||||
parts = match.split('.')
|
||||
if len(parts) > 2:
|
||||
return match + ' dollars' # Unexpected format
|
||||
dollars = int(parts[0]) if parts[0] else 0
|
||||
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
||||
if dollars and cents:
|
||||
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
||||
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
|
||||
elif dollars:
|
||||
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
||||
return '%s %s' % (dollars, dollar_unit)
|
||||
elif cents:
|
||||
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||
return '%s %s' % (cents, cent_unit)
|
||||
else:
|
||||
return 'zero dollars'
|
||||
|
||||
|
||||
def _expand_ordinal(m):
|
||||
return _inflect.number_to_words(m.group(0))
|
||||
|
||||
|
||||
def _expand_number(m):
|
||||
num = int(m.group(0))
|
||||
if num > 1000 and num < 3000:
|
||||
if num == 2000:
|
||||
return 'two thousand'
|
||||
elif num > 2000 and num < 2010:
|
||||
return 'two thousand ' + _inflect.number_to_words(num % 100)
|
||||
elif num % 100 == 0:
|
||||
return _inflect.number_to_words(num // 100) + ' hundred'
|
||||
else:
|
||||
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
|
||||
else:
|
||||
return _inflect.number_to_words(num, andword='')
|
||||
|
||||
|
||||
def normalize_numbers(text):
|
||||
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||
text = re.sub(_pounds_re, r'\1 pounds', text)
|
||||
text = re.sub(_dollars_re, _expand_dollars, text)
|
||||
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
||||
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
||||
text = re.sub(_number_re, _expand_number, text)
|
||||
return text
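# Illustrative behaviour (exact wording depends on the inflect library):
#   normalize_numbers("I paid $3.50 for 2 books in 2016.")
#   # -> roughly "I paid three dollars, fifty cents for two books in twenty sixteen."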
|
69
text/english.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
# Code from https://github.com/keithito/tacotron/blob/master/util/numbers.py
|
||||
import re
import inflect
|
||||
|
||||
|
||||
_inflect = inflect.engine()
|
||||
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
|
||||
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
|
||||
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
|
||||
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
|
||||
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
|
||||
_number_re = re.compile(r'[0-9]+')
|
||||
|
||||
|
||||
def _remove_commas(m):
|
||||
return m.group(1).replace(',', '')
|
||||
|
||||
|
||||
def _expand_decimal_point(m):
|
||||
return m.group(1).replace('.', ' point ')
|
||||
|
||||
|
||||
def _expand_dollars(m):
|
||||
match = m.group(1)
|
||||
parts = match.split('.')
|
||||
if len(parts) > 2:
|
||||
return match + ' dollars' # Unexpected format
|
||||
dollars = int(parts[0]) if parts[0] else 0
|
||||
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
||||
if dollars and cents:
|
||||
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
||||
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
|
||||
elif dollars:
|
||||
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
||||
return '%s %s' % (dollars, dollar_unit)
|
||||
elif cents:
|
||||
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||
return '%s %s' % (cents, cent_unit)
|
||||
else:
|
||||
return 'zero dollars'
|
||||
|
||||
|
||||
def _expand_ordinal(m):
|
||||
return _inflect.number_to_words(m.group(0))
|
||||
|
||||
|
||||
def _expand_number(m):
|
||||
num = int(m.group(0))
|
||||
if num > 1000 and num < 3000:
|
||||
if num == 2000:
|
||||
return 'two thousand'
|
||||
elif num > 2000 and num < 2010:
|
||||
return 'two thousand ' + _inflect.number_to_words(num % 100)
|
||||
elif num % 100 == 0:
|
||||
return _inflect.number_to_words(num // 100) + ' hundred'
|
||||
else:
|
||||
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
|
||||
else:
|
||||
return _inflect.number_to_words(num, andword='')
|
||||
|
||||
|
||||
def normalize(text):
|
||||
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||
text = re.sub(_pounds_re, r'\1 pounds', text)
|
||||
text = re.sub(_dollars_re, _expand_dollars, text)
|
||||
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
||||
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
||||
text = re.sub(_number_re, _expand_number, text)
|
||||
return text
|
172
text/ko_dictionary.py
Normal file
|
@ -0,0 +1,172 @@
|
|||
etc_dictionary = {
|
||||
'2 30대': '이삼십대',
|
||||
'20~30대': '이삼십대',
|
||||
'20, 30대': '이십대 삼십대',
|
||||
'1+1': '원플러스원',
|
||||
'3에서 6개월인': '3개월에서 육개월인',
|
||||
}
|
||||
|
||||
english_dictionary = {
|
||||
'Devsisters': '데브시스터즈',
|
||||
'track': '트랙',
|
||||
|
||||
# krbook
|
||||
'LA': '엘에이',
|
||||
'LG': '엘지',
|
||||
'KOREA': '코리아',
|
||||
'JSA': '제이에스에이',
|
||||
'PGA': '피지에이',
|
||||
'GA': '지에이',
|
||||
'idol': '아이돌',
|
||||
'KTX': '케이티엑스',
|
||||
'AC': '에이씨',
|
||||
'DVD': '디비디',
|
||||
'US': '유에스',
|
||||
'CNN': '씨엔엔',
|
||||
'LPGA': '엘피지에이',
|
||||
'P': '피',
|
||||
'L': '엘',
|
||||
'T': '티',
|
||||
'B': '비',
|
||||
'C': '씨',
|
||||
'BIFF': '비아이에프에프',
|
||||
'GV': '지비',
|
||||
|
||||
# JTBC
|
||||
'IT': '아이티',
|
||||
'IQ': '아이큐',
|
||||
'JTBC': '제이티비씨',
|
||||
'trickle down effect': '트리클 다운 이펙트',
|
||||
'trickle up effect': '트리클 업 이펙트',
|
||||
'down': '다운',
|
||||
'up': '업',
|
||||
'FCK': '에프씨케이',
|
||||
'AP': '에이피',
|
||||
'WHERETHEWILDTHINGSARE': '',
|
||||
'Rashomon Effect': '',
|
||||
'O': '오',
|
||||
'OO': '오오',
|
||||
'B': '비',
|
||||
'GDP': '지디피',
|
||||
'CIPA': '씨아이피에이',
|
||||
'YS': '와이에스',
|
||||
'Y': '와이',
|
||||
'S': '에스',
|
||||
'JTBC': '제이티비씨',
|
||||
'PC': '피씨',
|
||||
'bill': '빌',
|
||||
'Halmuny': '하모니', #####
|
||||
'X': '엑스',
|
||||
'SNS': '에스엔에스',
|
||||
'ability': '어빌리티',
|
||||
'shy': '',
|
||||
'CCTV': '씨씨티비',
|
||||
'IT': '아이티',
|
||||
'the tenth man': '더 텐쓰 맨', ####
|
||||
'L': '엘',
|
||||
'PC': '피씨',
|
||||
'YSDJJPMB': '', ########
|
||||
'Content Attitude Timing': '컨텐트 애티튜드 타이밍',
|
||||
'CAT': '캣',
|
||||
'IS': '아이에스',
|
||||
'SNS': '에스엔에스',
|
||||
'K': '케이',
|
||||
'Y': '와이',
|
||||
'KDI': '케이디아이',
|
||||
'DOC': '디오씨',
|
||||
'CIA': '씨아이에이',
|
||||
'PBS': '피비에스',
|
||||
'D': '디',
|
||||
'PPropertyPositionPowerPrisonP': '',  # value missing in the source; left blank like other untranscribed entries
|
||||
'S': '에스',
|
||||
'francisco': '프란시스코',
|
||||
'I': '아이',
|
||||
'III': '아이아이', ######
|
||||
'No joke': '노 조크',
|
||||
'BBK': '비비케이',
|
||||
'LA': '엘에이',
|
||||
'Don': '',
|
||||
't worry be happy': ' 워리 비 해피',
|
||||
'NO': '엔오', #####
|
||||
'it was our sky': '잇 워즈 아워 스카이',
|
||||
'it is our sky': '잇 이즈 아워 스카이', ####
|
||||
'NEIS': '엔이아이에스', #####
|
||||
'IMF': '아이엠에프',
|
||||
'apology': '어폴로지',
|
||||
'humble': '험블',
|
||||
'M': '엠',
|
||||
'Nowhere Man': '노웨어 맨',
|
||||
'The Tenth Man': '더 텐쓰 맨',
|
||||
'PBS': '피비에스',
|
||||
'BBC': '비비씨',
|
||||
'MRJ': '엠알제이',
|
||||
'CCTV': '씨씨티비',
|
||||
'Pick me up': '픽 미 업',
|
||||
'DNA': '디엔에이',
|
||||
'UN': '유엔',
|
||||
'STOP': '스탑', #####
|
||||
'PRESS': '프레스', #####
|
||||
'not to be': '낫 투비',
|
||||
'Denial': '디나이얼',
|
||||
'G': '지',
|
||||
'IMF': '아이엠에프',
|
||||
'GDP': '지디피',
|
||||
'JTBC': '제이티비씨',
|
||||
'Time flies like an arrow': '타임 플라이즈 라이크 언 애로우',
|
||||
'DDT': '디디티',
|
||||
'AI': '에이아이',
|
||||
'Z': '제트',
|
||||
'OECD': '오이씨디',
|
||||
'N': '앤',
|
||||
'A': '에이',
|
||||
'MB': '엠비',
|
||||
'EH': '이에이치',
|
||||
'IS': '아이에스',
|
||||
'TV': '티비',
|
||||
'MIT': '엠아이티',
|
||||
'KBO': '케이비오',
|
||||
'I love America': '아이 러브 아메리카',
|
||||
'SF': '에스에프',
|
||||
'Q': '큐',
|
||||
'KFX': '케이에프엑스',
|
||||
'PM': '피엠',
|
||||
'Prime Minister': '프라임 미니스터',
|
||||
'Swordline': '스워드라인',
|
||||
'TBS': '티비에스',
|
||||
'DDT': '디디티',
|
||||
'CS': '씨에스',
|
||||
'Reflecting Absence': '리플렉팅 앱센스',
|
||||
'PBS': '피비에스',
|
||||
'Drum being beaten by everyone': '드럼 빙 비튼 바이 에브리원',
|
||||
'negative pressure': '네거티브 프레셔',
|
||||
'F': '에프',
|
||||
'KIA': '기아',
|
||||
'FTA': '에프티에이',
|
||||
'Que sais-je': '',
|
||||
'UFC': '유에프씨',
|
||||
'P': '피',
|
||||
'DJ': '디제이',
|
||||
'Chaebol': '채벌',
|
||||
'BBC': '비비씨',
|
||||
'OECD': '오이씨디',
|
||||
'BC': '삐씨',
|
||||
'C': '씨',
|
||||
'B': '비',
|
||||
'KY': '케이와이',
|
||||
'K': '케이',
|
||||
'CEO': '씨이오',
|
||||
'YH': '와이에치',
|
||||
'IS': '아이에스',
|
||||
'who are you': '후 얼 유',
|
||||
'Y': '와이',
|
||||
'The Devils Advocate': '더 데빌즈 어드보카트',
|
||||
'YS': '와이에스',
|
||||
'so sorry': '쏘 쏘리',
|
||||
'Santa': '산타',
|
||||
'Big Endian': '빅 엔디안',
|
||||
'Small Endian': '스몰 엔디안',
|
||||
'Oh Captain My Captain': '오 캡틴 마이 캡틴',
|
||||
'AIB': '에이아이비',
|
||||
'K': '케이',
|
||||
'PBS': '피비에스',
|
||||
}
|
319
text/korean.py
Normal file
|
@ -0,0 +1,319 @@
|
|||
# Code based on
|
||||
|
||||
import re
|
||||
import os
|
||||
import ast
|
||||
import json
|
||||
from jamo import hangul_to_jamo, h2j, j2h
|
||||
|
||||
from .ko_dictionary import english_dictionary, etc_dictionary
|
||||
|
||||
PAD = '_'
|
||||
EOS = '~'
|
||||
PUNC = '!\'(),-.:;?'
|
||||
SPACE = ' '
|
||||
|
||||
JAMO_LEADS = "".join([chr(_) for _ in range(0x1100, 0x1113)])
|
||||
JAMO_VOWELS = "".join([chr(_) for _ in range(0x1161, 0x1176)])
|
||||
JAMO_TAILS = "".join([chr(_) for _ in range(0x11A8, 0x11C3)])
|
||||
|
||||
VALID_CHARS = JAMO_LEADS + JAMO_VOWELS + JAMO_TAILS + PUNC + SPACE
|
||||
ALL_SYMBOLS = PAD + EOS + VALID_CHARS
|
||||
|
||||
char_to_id = {c: i for i, c in enumerate(ALL_SYMBOLS)}
|
||||
id_to_char = {i: c for i, c in enumerate(ALL_SYMBOLS)}
|
||||
|
||||
quote_checker = """([`"'"“‘])(.+?)([`"'"”’])"""
|
||||
|
||||
def is_lead(char):
|
||||
return char in JAMO_LEADS
|
||||
|
||||
def is_vowel(char):
|
||||
return char in JAMO_VOWELS
|
||||
|
||||
def is_tail(char):
|
||||
return char in JAMO_TAILS
|
||||
|
||||
def get_mode(char):
|
||||
if is_lead(char):
|
||||
return 0
|
||||
elif is_vowel(char):
|
||||
return 1
|
||||
elif is_tail(char):
|
||||
return 2
|
||||
else:
|
||||
return -1
|
||||
|
||||
def _get_text_from_candidates(candidates):
|
||||
if len(candidates) == 0:
|
||||
return ""
|
||||
elif len(candidates) == 1:
|
||||
return _jamo_char_to_hcj(candidates[0])
|
||||
else:
|
||||
return j2h(**dict(zip(["lead", "vowel", "tail"], candidates)))
|
||||
|
||||
def jamo_to_korean(text):
|
||||
text = h2j(text)
|
||||
|
||||
idx = 0
|
||||
new_text = ""
|
||||
candidates = []
|
||||
|
||||
while True:
|
||||
if idx >= len(text):
|
||||
new_text += _get_text_from_candidates(candidates)
|
||||
break
|
||||
|
||||
char = text[idx]
|
||||
mode = get_mode(char)
|
||||
|
||||
if mode == 0:
|
||||
new_text += _get_text_from_candidates(candidates)
|
||||
candidates = [char]
|
||||
elif mode == -1:
|
||||
new_text += _get_text_from_candidates(candidates)
|
||||
new_text += char
|
||||
candidates = []
|
||||
else:
|
||||
candidates.append(char)
|
||||
|
||||
idx += 1
|
||||
return new_text
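# Example: recombine decomposed jamo into Hangul syllables, e.g. reversing the
# output of tokenize() defined below (illustrative):
#   jamo_to_korean("".join(tokenize("한국어", as_id=False)[:-1]))  # -> "한국어"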
|
||||
|
||||
num_to_kor = {
|
||||
'0': '영',
|
||||
'1': '일',
|
||||
'2': '이',
|
||||
'3': '삼',
|
||||
'4': '사',
|
||||
'5': '오',
|
||||
'6': '육',
|
||||
'7': '칠',
|
||||
'8': '팔',
|
||||
'9': '구',
|
||||
}
|
||||
|
||||
unit_to_kor1 = {
|
||||
'%': '퍼센트',
|
||||
'cm': '센치미터',
|
||||
'mm': '밀리미터',
|
||||
'km': '킬로미터',
|
||||
'kg': '킬로그람',
|
||||
}
|
||||
unit_to_kor2 = {
|
||||
'm': '미터',
|
||||
}
|
||||
|
||||
upper_to_kor = {
|
||||
'A': '에이',
|
||||
'B': '비',
|
||||
'C': '씨',
|
||||
'D': '디',
|
||||
'E': '이',
|
||||
'F': '에프',
|
||||
'G': '지',
|
||||
'H': '에이치',
|
||||
'I': '아이',
|
||||
'J': '제이',
|
||||
'K': '케이',
|
||||
'L': '엘',
|
||||
'M': '엠',
|
||||
'N': '엔',
|
||||
'O': '오',
|
||||
'P': '피',
|
||||
'Q': '큐',
|
||||
'R': '알',
|
||||
'S': '에스',
|
||||
'T': '티',
|
||||
'U': '유',
|
||||
'V': '브이',
|
||||
'W': '더블유',
|
||||
'X': '엑스',
|
||||
'Y': '와이',
|
||||
'Z': '지',
|
||||
}
|
||||
|
||||
def compare_sentence_with_jamo(text1, text2):
|
||||
return h2j(text1) != h2j(text2)
|
||||
|
||||
def tokenize(text, as_id=False):
|
||||
text = normalize(text)
|
||||
tokens = list(hangul_to_jamo(text))
|
||||
|
||||
if as_id:
|
||||
return [char_to_id[token] for token in tokens] + [char_to_id[EOS]]
|
||||
else:
|
||||
return [token for token in tokens] + [EOS]
|
||||
|
||||
def tokenizer_fn(iterator):
|
||||
return (token for x in iterator for token in tokenize(x, as_id=False))
|
||||
|
||||
def normalize(text):
|
||||
text = text.strip()
|
||||
|
||||
text = re.sub('\(\d+일\)', '', text)
|
||||
text = re.sub('\([⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]+\)', '', text)
|
||||
|
||||
text = normalize_with_dictionary(text, etc_dictionary)
|
||||
text = normalize_english(text)
|
||||
text = re.sub('[a-zA-Z]+', normalize_upper, text)
|
||||
|
||||
text = normalize_quote(text)
|
||||
text = normalize_number(text)
|
||||
|
||||
return text
|
||||
|
||||
def normalize_with_dictionary(text, dic):
|
||||
if any(key in text for key in dic.keys()):
|
||||
pattern = re.compile('|'.join(re.escape(key) for key in dic.keys()))
|
||||
return pattern.sub(lambda x: dic[x.group()], text)
|
||||
else:
|
||||
return text
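# Example: substitute every dictionary key that occurs in the text, e.g.
#   normalize_with_dictionary("20%", unit_to_kor1)  # -> "20퍼센트"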
|
||||
|
||||
def normalize_english(text):
|
||||
def fn(m):
|
||||
word = m.group()
|
||||
if word in english_dictionary:
|
||||
return english_dictionary.get(word)
|
||||
else:
|
||||
return word
|
||||
|
||||
text = re.sub("([A-Za-z]+)", fn, text)
|
||||
return text
|
||||
|
||||
def normalize_upper(text):
|
||||
text = text.group(0)
|
||||
|
||||
if all([char.isupper() for char in text]):
|
||||
return "".join(upper_to_kor[char] for char in text)
|
||||
else:
|
||||
return text
|
||||
|
||||
def normalize_quote(text):
|
||||
def fn(found_text):
|
||||
from nltk import sent_tokenize  # NLTK doesn't play well with multiprocessing, so import it lazily here
|
||||
|
||||
found_text = found_text.group()
|
||||
unquoted_text = found_text[1:-1]
|
||||
|
||||
sentences = sent_tokenize(unquoted_text)
|
||||
return " ".join(["'{}'".format(sent) for sent in sentences])
|
||||
|
||||
return re.sub(quote_checker, fn, text)
|
||||
|
||||
number_checker = "([+-]?\d[\d,]*)[\.]?\d*"
|
||||
count_checker = "(시|명|가지|살|마리|포기|송이|수|톨|통|점|개|벌|척|채|다발|그루|자루|줄|켤레|그릇|잔|마디|상자|사람|곡|병|판)"
|
||||
|
||||
def normalize_number(text):
|
||||
text = normalize_with_dictionary(text, unit_to_kor1)
|
||||
text = normalize_with_dictionary(text, unit_to_kor2)
|
||||
text = re.sub(number_checker + count_checker,
|
||||
lambda x: number_to_korean(x, True), text)
|
||||
text = re.sub(number_checker,
|
||||
lambda x: number_to_korean(x, False), text)
|
||||
return text
|
||||
|
||||
num_to_kor1 = [""] + list("일이삼사오육칠팔구")
|
||||
num_to_kor2 = [""] + list("만억조경해")
|
||||
num_to_kor3 = [""] + list("십백천")
|
||||
|
||||
#count_to_kor1 = [""] + ["하나","둘","셋","넷","다섯","여섯","일곱","여덟","아홉"]
|
||||
count_to_kor1 = [""] + ["한","두","세","네","다섯","여섯","일곱","여덟","아홉"]
|
||||
|
||||
count_tenth_dict = {
|
||||
"십": "열",
|
||||
"두십": "스물",
|
||||
"세십": "서른",
|
||||
"네십": "마흔",
|
||||
"다섯십": "쉰",
|
||||
"여섯십": "예순",
|
||||
"일곱십": "일흔",
|
||||
"여덟십": "여든",
|
||||
"아홉십": "아흔",
|
||||
}
|
||||
|
||||
|
||||
|
||||
def number_to_korean(num_str, is_count=False):
|
||||
if is_count:
|
||||
num_str, unit_str = num_str.group(1), num_str.group(2)
|
||||
else:
|
||||
num_str, unit_str = num_str.group(), ""
|
||||
|
||||
num_str = num_str.replace(',', '')
|
||||
num = ast.literal_eval(num_str)
|
||||
|
||||
if num == 0:
|
||||
return "영"
|
||||
|
||||
check_float = num_str.split('.')
|
||||
if len(check_float) == 2:
|
||||
digit_str, float_str = check_float
|
||||
elif len(check_float) >= 3:
|
||||
raise Exception(" [!] Wrong number format")
|
||||
else:
|
||||
digit_str, float_str = check_float[0], None
|
||||
|
||||
if is_count and float_str is not None:
|
||||
raise Exception(" [!] `is_count` and float number does not fit each other")
|
||||
|
||||
digit = int(digit_str)
|
||||
|
||||
if digit_str.startswith("-"):
|
||||
digit, digit_str = abs(digit), str(abs(digit))
|
||||
|
||||
kor = ""
|
||||
size = len(str(digit))
|
||||
tmp = []
|
||||
|
||||
for i, v in enumerate(digit_str, start=1):
|
||||
v = int(v)
|
||||
|
||||
if v != 0:
|
||||
if is_count:
|
||||
tmp += count_to_kor1[v]
|
||||
else:
|
||||
tmp += num_to_kor1[v]
|
||||
|
||||
tmp += num_to_kor3[(size - i) % 4]
|
||||
|
||||
if (size - i) % 4 == 0 and len(tmp) != 0:
|
||||
kor += "".join(tmp)
|
||||
tmp = []
|
||||
kor += num_to_kor2[int((size - i) / 4)]
|
||||
|
||||
if is_count:
|
||||
if kor.startswith("한") and len(kor) > 1:
|
||||
kor = kor[1:]
|
||||
|
||||
if any(word in kor for word in count_tenth_dict):
|
||||
kor = re.sub(
|
||||
'|'.join(count_tenth_dict.keys()),
|
||||
lambda x: count_tenth_dict[x.group()], kor)
|
||||
|
||||
if not is_count and kor.startswith("일") and len(kor) > 1:
|
||||
kor = kor[1:]
|
||||
|
||||
if float_str is not None:
|
||||
kor += "쩜 "
|
||||
kor += re.sub('\d', lambda x: num_to_kor[x.group()], float_str)
|
||||
|
||||
if num_str.startswith("+"):
|
||||
kor = "플러스 " + kor
|
||||
elif num_str.startswith("-"):
|
||||
kor = "마이너스 " + kor
|
||||
|
||||
return kor + unit_str
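# Illustrative behaviour when called through normalize_number() above:
#   normalize_number("3개")   # -> "세개"  (counter form)
#   normalize_number("12%")   # -> "십이퍼센트"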
|
||||
|
||||
if __name__ == "__main__":
|
||||
def test_normalize(text):
|
||||
print(text)
|
||||
print(normalize(text))
|
||||
print("="*30)
|
||||
|
||||
test_normalize("JTBC는 JTBCs를 DY는 A가 Absolute")
|
||||
test_normalize("오늘(13일) 101마리 강아지가")
|
||||
test_normalize('"저돌"(猪突) 입니다.')
|
||||
test_normalize('비대위원장이 지난 1월 이런 말을 했습니다. “난 그냥 산돼지처럼 돌파하는 스타일이다”')
|
||||
test_normalize("지금은 -12.35%였고 종류는 5가지와 19가지, 그리고 55가지였다")
|
||||
test_normalize("JTBC는 TH와 K 양이 2017년 9월 12일 오후 12시에 24살이 된다")
|
13
text/symbols.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
'''
|
||||
Defines the set of symbols used in text input to the model.
|
||||
|
||||
The default is a set of ASCII characters that works well for English or text that has been run
|
||||
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
|
||||
'''
|
||||
from jamo import h2j, j2h
|
||||
from jamo.jamo import _jamo_char_to_hcj
|
||||
|
||||
from .korean import ALL_SYMBOLS, PAD, EOS
|
||||
|
||||
#symbols = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
|
||||
symbols = ALL_SYMBOLS
|
322
train.py
Normal file
|
@ -0,0 +1,322 @@
|
|||
import os
|
||||
import time
|
||||
import math
|
||||
import argparse
|
||||
import traceback
|
||||
import subprocess
|
||||
import numpy as np
|
||||
from jamo import h2j
|
||||
import tensorflow as tf
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
|
||||
from hparams import hparams, hparams_debug_string
|
||||
from models import create_model, get_most_recent_checkpoint
|
||||
|
||||
from utils import ValueWindow, prepare_dirs
|
||||
from utils import infolog, warning, plot, load_hparams
|
||||
from utils import get_git_revision_hash, get_git_diff, str2bool, parallel_run
|
||||
|
||||
from audio import save_audio, inv_spectrogram
|
||||
from text import sequence_to_text, text_to_sequence
|
||||
from datasets.datafeeder import DataFeeder, _prepare_inputs
|
||||
|
||||
log = infolog.log
|
||||
|
||||
|
||||
def create_batch_inputs_from_texts(texts):
|
||||
sequences = [text_to_sequence(text) for text in texts]
|
||||
|
||||
inputs = _prepare_inputs(sequences)
|
||||
input_lengths = np.asarray([len(x) for x in inputs], dtype=np.int32)
|
||||
|
||||
for idx, (seq, text) in enumerate(zip(inputs, texts)):
|
||||
recovered_text = sequence_to_text(seq, skip_eos_and_pad=True)
|
||||
if recovered_text != h2j(text):
|
||||
log(" [{}] {}".format(idx, text))
|
||||
log(" [{}] {}".format(idx, recovered_text))
|
||||
log("="*30)
|
||||
|
||||
return inputs, input_lengths
|
||||
|
||||
|
||||
def get_git_commit():
|
||||
subprocess.check_output(['git', 'diff-index', '--quiet', 'HEAD']) # Verify client is clean
|
||||
commit = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode().strip()[:10]
|
||||
log('Git commit: %s' % commit)
|
||||
return commit
|
||||
|
||||
|
||||
def add_stats(model, model2=None, scope_name='train'):
|
||||
with tf.variable_scope(scope_name) as scope:
|
||||
summaries = [
|
||||
tf.summary.scalar('loss_mel', model.mel_loss),
|
||||
tf.summary.scalar('loss_linear', model.linear_loss),
|
||||
tf.summary.scalar('loss', model.loss_without_coeff),
|
||||
]
|
||||
|
||||
if scope_name == 'train':
|
||||
gradient_norms = [tf.norm(grad) for grad in model.gradients if grad is not None]
|
||||
|
||||
summaries.extend([
|
||||
tf.summary.scalar('learning_rate', model.learning_rate),
|
||||
tf.summary.scalar('max_gradient_norm', tf.reduce_max(gradient_norms)),
|
||||
])
|
||||
|
||||
if model2 is not None:
|
||||
with tf.variable_scope('gap_test-train') as scope:
|
||||
summaries.extend([
|
||||
tf.summary.scalar('loss_mel',
|
||||
model.mel_loss - model2.mel_loss),
|
||||
tf.summary.scalar('loss_linear',
|
||||
model.linear_loss - model2.linear_loss),
|
||||
tf.summary.scalar('loss',
|
||||
model.loss_without_coeff - model2.loss_without_coeff),
|
||||
])
|
||||
|
||||
return tf.summary.merge(summaries)
|
||||
|
||||
|
||||
def save_and_plot_fn(args, log_dir, step, loss, prefix):
|
||||
idx, (seq, spec, align) = args
|
||||
|
||||
audio_path = os.path.join(
|
||||
log_dir, '{}-step-{:09d}-audio{:03d}.wav'.format(prefix, step, idx))
|
||||
align_path = os.path.join(
|
||||
log_dir, '{}-step-{:09d}-align{:03d}.png'.format(prefix, step, idx))
|
||||
|
||||
waveform = inv_spectrogram(spec.T)
|
||||
save_audio(waveform, audio_path)
|
||||
|
||||
info_text = 'step={:d}, loss={:.5f}'.format(step, loss)
|
||||
plot.plot_alignment(
|
||||
align, align_path, info=info_text,
|
||||
text=sequence_to_text(seq,
|
||||
skip_eos_and_pad=True, combine_jamo=True))
|
||||
|
||||
def save_and_plot(sequences, spectrograms,
|
||||
alignments, log_dir, step, loss, prefix):
|
||||
|
||||
fn = partial(save_and_plot_fn,
|
||||
log_dir=log_dir, step=step, loss=loss, prefix=prefix)
|
||||
items = list(enumerate(zip(sequences, spectrograms, alignments)))
|
||||
|
||||
parallel_run(fn, items, parallel=False)
|
||||
log('Test finished for step {}.'.format(step))
|
||||
|
||||
|
||||
def train(log_dir, config):
|
||||
config.data_paths = config.data_paths
|
||||
|
||||
data_dirs = [os.path.join(data_path, "data") \
|
||||
for data_path in config.data_paths]
|
||||
num_speakers = len(data_dirs)
|
||||
config.num_test = config.num_test_per_speaker * num_speakers
|
||||
|
||||
if num_speakers > 1 and hparams.model_type not in ["deepvoice", "simple"]:
|
||||
raise Exception("[!] Unkown model_type for multi-speaker: {}".format(config.model_type))
|
||||
|
||||
commit = get_git_commit() if config.git else 'None'
|
||||
checkpoint_path = os.path.join(log_dir, 'model.ckpt')
|
||||
|
||||
log(' [*] git rev-parse HEAD:\n%s' % get_git_revision_hash())
|
||||
log('='*50)
|
||||
log(' [*] git diff:\n%s' % get_git_diff())
|
||||
log('='*50)
|
||||
log(' [*] Checkpoint path: %s' % checkpoint_path)
|
||||
log(' [*] Loading training data from: %s' % data_dirs)
|
||||
log(' [*] Using model: %s' % config.model_dir)
|
||||
log(hparams_debug_string())
|
||||
|
||||
# Set up DataFeeder:
|
||||
coord = tf.train.Coordinator()
|
||||
with tf.variable_scope('datafeeder') as scope:
|
||||
train_feeder = DataFeeder(
|
||||
coord, data_dirs, hparams, config, 32,
|
||||
data_type='train', batch_size=hparams.batch_size)
|
||||
test_feeder = DataFeeder(
|
||||
coord, data_dirs, hparams, config, 8,
|
||||
data_type='test', batch_size=config.num_test)
|
||||
|
||||
# Set up model:
|
||||
is_randomly_initialized = config.initialize_path is None
|
||||
global_step = tf.Variable(0, name='global_step', trainable=False)
|
||||
|
||||
with tf.variable_scope('model') as scope:
|
||||
model = create_model(hparams)
|
||||
model.initialize(
|
||||
train_feeder.inputs, train_feeder.input_lengths,
|
||||
num_speakers, train_feeder.speaker_id,
|
||||
train_feeder.mel_targets, train_feeder.linear_targets,
|
||||
train_feeder.loss_coeff,
|
||||
is_randomly_initialized=is_randomly_initialized)
|
||||
|
||||
model.add_loss()
|
||||
model.add_optimizer(global_step)
|
||||
train_stats = add_stats(model, scope_name='stats') # legacy
|
||||
|
||||
with tf.variable_scope('model', reuse=True) as scope:
|
||||
test_model = create_model(hparams)
|
||||
test_model.initialize(
|
||||
test_feeder.inputs, test_feeder.input_lengths,
|
||||
num_speakers, test_feeder.speaker_id,
|
||||
test_feeder.mel_targets, test_feeder.linear_targets,
|
||||
test_feeder.loss_coeff, rnn_decoder_test_mode=True,
|
||||
is_randomly_initialized=is_randomly_initialized)
|
||||
test_model.add_loss()
|
||||
|
||||
test_stats = add_stats(test_model, model, scope_name='test')
|
||||
test_stats = tf.summary.merge([test_stats, train_stats])
|
||||
|
||||
# Bookkeeping:
|
||||
step = 0
|
||||
time_window = ValueWindow(100)
|
||||
loss_window = ValueWindow(100)
|
||||
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)
|
||||
|
||||
sess_config = tf.ConfigProto(
|
||||
log_device_placement=False,
|
||||
allow_soft_placement=True)
|
||||
sess_config.gpu_options.allow_growth=True
|
||||
|
||||
# Train!
|
||||
#with tf.Session(config=sess_config) as sess:
|
||||
with tf.Session() as sess:
|
||||
try:
|
||||
summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
if config.load_path:
|
||||
# Restore from a checkpoint if the user requested it.
|
||||
restore_path = get_most_recent_checkpoint(config.model_dir)
|
||||
saver.restore(sess, restore_path)
|
||||
log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)
|
||||
elif config.initialize_path:
|
||||
restore_path = get_most_recent_checkpoint(config.initialize_path)
|
||||
saver.restore(sess, restore_path)
|
||||
log('Initialized from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)
|
||||
|
||||
zero_step_assign = tf.assign(global_step, 0)
|
||||
sess.run(zero_step_assign)
|
||||
|
||||
start_step = sess.run(global_step)
|
||||
log('='*50)
|
||||
log(' [*] Global step is reset to {}'. \
|
||||
format(start_step))
|
||||
log('='*50)
|
||||
else:
|
||||
log('Starting new training run at commit: %s' % commit, slack=True)
|
||||
|
||||
start_step = sess.run(global_step)
|
||||
|
||||
train_feeder.start_in_session(sess, start_step)
|
||||
test_feeder.start_in_session(sess, start_step)
|
||||
|
||||
while not coord.should_stop():
|
||||
start_time = time.time()
|
||||
step, loss, opt = sess.run(
|
||||
[global_step, model.loss_without_coeff, model.optimize],
|
||||
feed_dict=model.get_dummy_feed_dict())
|
||||
|
||||
time_window.append(time.time() - start_time)
|
||||
loss_window.append(loss)
|
||||
|
||||
message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
|
||||
step, time_window.average, loss, loss_window.average)
|
||||
log(message, slack=(step % config.checkpoint_interval == 0))
|
||||
|
||||
if loss > 100 or math.isnan(loss):
|
||||
log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True)
|
||||
raise Exception('Loss Exploded')
|
||||
|
||||
if step % config.summary_interval == 0:
|
||||
log('Writing summary at step: %d' % step)
|
||||
|
||||
feed_dict = {
|
||||
**model.get_dummy_feed_dict(),
|
||||
**test_model.get_dummy_feed_dict()
|
||||
}
|
||||
summary_writer.add_summary(sess.run(
|
||||
test_stats, feed_dict=feed_dict), step)
|
||||
|
||||
if step % config.checkpoint_interval == 0:
|
||||
log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
|
||||
saver.save(sess, checkpoint_path, global_step=step)
|
||||
|
||||
if step % config.test_interval == 0:
|
||||
log('Saving audio and alignment...')
|
||||
num_test = config.num_test
|
||||
|
||||
fetches = [
|
||||
model.inputs[:num_test],
|
||||
model.linear_outputs[:num_test],
|
||||
model.alignments[:num_test],
|
||||
test_model.inputs[:num_test],
|
||||
test_model.linear_outputs[:num_test],
|
||||
test_model.alignments[:num_test],
|
||||
]
|
||||
feed_dict = {
|
||||
**model.get_dummy_feed_dict(),
|
||||
**test_model.get_dummy_feed_dict()
|
||||
}
|
||||
|
||||
sequences, spectrograms, alignments, \
|
||||
test_sequences, test_spectrograms, test_alignments = \
|
||||
sess.run(fetches, feed_dict=feed_dict)
|
||||
|
||||
save_and_plot(sequences[:1], spectrograms[:1], alignments[:1],
|
||||
log_dir, step, loss, "train")
|
||||
save_and_plot(test_sequences, test_spectrograms, test_alignments,
|
||||
log_dir, step, loss, "test")
|
||||
|
||||
except Exception as e:
|
||||
log('Exiting due to exception: %s' % e, slack=True)
|
||||
traceback.print_exc()
|
||||
coord.request_stop(e)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('--log_dir', default='logs')
|
||||
parser.add_argument('--data_paths', default='datasets/kr_example')
|
||||
parser.add_argument('--load_path', default=None)
|
||||
parser.add_argument('--initialize_path', default=None)
|
||||
|
||||
parser.add_argument('--num_test_per_speaker', type=int, default=2)
|
||||
parser.add_argument('--random_seed', type=int, default=123)
|
||||
parser.add_argument('--summary_interval', type=int, default=100)
|
||||
parser.add_argument('--test_interval', type=int, default=500)
|
||||
parser.add_argument('--checkpoint_interval', type=int, default=1000)
|
||||
parser.add_argument('--skip_path_filter',
|
||||
type=str2bool, default=False, help='Use only for debugging')
|
||||
|
||||
parser.add_argument('--slack_url',
|
||||
help='Slack webhook URL to get periodic reports.')
|
||||
parser.add_argument('--git', action='store_true',
|
||||
help='If set, verify that the client is clean.')
|
||||
|
||||
config = parser.parse_args()
|
||||
config.data_paths = config.data_paths.split(",")
|
||||
setattr(hparams, "num_speakers", len(config.data_paths))
|
||||
|
||||
prepare_dirs(config, hparams)
|
||||
|
||||
log_path = os.path.join(config.model_dir, 'train.log')
|
||||
infolog.init(log_path, config.model_dir, config.slack_url)
|
||||
|
||||
tf.set_random_seed(config.random_seed)
|
||||
|
||||
if any("krbook" not in data_path for data_path in config.data_paths) and \
|
||||
hparams.sample_rate != 20000:
|
||||
warning("Detect non-krbook dataset. Set sampling rate from {} to 20000".\
|
||||
format(hparams.sample_rate))
|
||||
|
||||
if config.load_path is not None and config.initialize_path is not None:
|
||||
raise Exception(" [!] Only one of load_path and initialize_path should be set")
|
||||
|
||||
train(config.model_dir, config)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
BIN
utils/NanumBarunGothic.ttf
Normal file
Binary file not shown.
223
utils/__init__.py
Normal file
|
@ -0,0 +1,223 @@
|
|||
import os
|
||||
import re
|
||||
import sys
|
||||
import json
|
||||
import requests
|
||||
import subprocess
|
||||
from tqdm import tqdm
|
||||
from contextlib import closing
|
||||
from multiprocessing import Pool
|
||||
from collections import namedtuple
|
||||
from datetime import datetime, timedelta
|
||||
from shutil import copyfile as copy_file
|
||||
|
||||
PARAMS_NAME = "params.json"
|
||||
|
||||
class ValueWindow():
|
||||
def __init__(self, window_size=100):
|
||||
self._window_size = window_size
|
||||
self._values = []
|
||||
|
||||
def append(self, x):
|
||||
self._values = self._values[-(self._window_size - 1):] + [x]
|
||||
|
||||
@property
|
||||
def sum(self):
|
||||
return sum(self._values)
|
||||
|
||||
@property
|
||||
def count(self):
|
||||
return len(self._values)
|
||||
|
||||
@property
|
||||
def average(self):
|
||||
return self.sum / max(1, self.count)
|
||||
|
||||
def reset(self):
|
||||
self._values = []
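# Example: ValueWindow keeps only the most recent `window_size` values; train.py
# uses it to report moving averages of step time and loss.
#   w = ValueWindow(window_size=3)
#   for x in [1, 2, 3, 4]:
#       w.append(x)
#   w.average  # -> 3.0 (mean of 2, 3, 4)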
|
||||
|
||||
def prepare_dirs(config, hparams):
|
||||
if hasattr(config, "data_paths"):
|
||||
config.datasets = [
|
||||
os.path.basename(data_path) for data_path in config.data_paths]
|
||||
dataset_desc = "+".join(config.datasets)
|
||||
|
||||
if config.load_path:
|
||||
config.model_dir = config.load_path
|
||||
else:
|
||||
config.model_name = "{}_{}".format(dataset_desc, get_time())
|
||||
config.model_dir = os.path.join(config.log_dir, config.model_name)
|
||||
|
||||
for path in [config.log_dir, config.model_dir]:
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
if config.load_path:
|
||||
load_hparams(hparams, config.model_dir)
|
||||
else:
|
||||
setattr(hparams, "num_speakers", len(config.datasets))
|
||||
|
||||
save_hparams(config.model_dir, hparams)
|
||||
copy_file("hparams.py", os.path.join(config.model_dir, "hparams.py"))
|
||||
|
||||
def makedirs(path):
|
||||
if not os.path.exists(path):
|
||||
print(" [*] Make directories : {}".format(path))
|
||||
os.makedirs(path)
|
||||
|
||||
def remove_file(path):
|
||||
if os.path.exists(path):
|
||||
print(" [*] Removed: {}".format(path))
|
||||
os.remove(path)
|
||||
|
||||
def backup_file(path):
|
||||
root, ext = os.path.splitext(path)
|
||||
new_path = "{}.backup_{}{}".format(root, get_time(), ext)
|
||||
|
||||
os.rename(path, new_path)
|
||||
print(" [*] {} has backup: {}".format(path, new_path))
|
||||
|
||||
def get_time():
|
||||
return datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
def write_json(path, data):
|
||||
with open(path, 'w') as f:
|
||||
json.dump(data, f, indent=4, sort_keys=True, ensure_ascii=False)
|
||||
|
||||
def load_json(path, as_class=False):
|
||||
with open(path) as f:
|
||||
content = f.read()
|
||||
content = re.sub(",\s*}", "}", content)
|
||||
content = re.sub(",\s*]", "]", content)
|
||||
|
||||
if as_class:
|
||||
data = json.loads(content, object_hook=\
|
||||
lambda data: namedtuple('Data', data.keys())(*data.values()))
|
||||
else:
|
||||
data = json.loads(content)
|
||||
return data
|
||||
|
||||
def save_hparams(model_dir, hparams):
|
||||
param_path = os.path.join(model_dir, PARAMS_NAME)
|
||||
|
||||
info = eval(hparams.to_json(). \
|
||||
replace('true', 'True').replace('false', 'False'))
|
||||
write_json(param_path, info)
|
||||
|
||||
print(" [*] MODEL dir: {}".format(model_dir))
|
||||
print(" [*] PARAM path: {}".format(param_path))
|
||||
|
||||
def load_hparams(hparams, load_path, skip_list=[]):
|
||||
path = os.path.join(load_path, PARAMS_NAME)
|
||||
|
||||
new_hparams = load_json(path)
|
||||
hparams_keys = vars(hparams).keys()
|
||||
|
||||
for key, value in new_hparams.items():
|
||||
if key in skip_list or key not in hparams_keys:
|
||||
print("Skip {} because it not exists".format(key))
|
||||
continue
|
||||
|
||||
if key not in ['job_name', 'num_workers', 'display', 'is_train', 'load_path'] or \
|
||||
key == "pointer_load_path":
|
||||
original_value = getattr(hparams, key)
|
||||
if original_value != value:
|
||||
print("UPDATE {}: {} -> {}".format(key, getattr(hparams, key), value))
|
||||
setattr(hparams, key, value)
|
||||
|
||||
def add_prefix(path, prefix):
|
||||
dir_path, filename = os.path.dirname(path), os.path.basename(path)
|
||||
return "{}/{}.{}".format(dir_path, prefix, filename)
|
||||
|
||||
def add_postfix(path, postfix):
|
||||
path_without_ext, ext = path.rsplit('.', 1)
|
||||
return "{}.{}.{}".format(path_without_ext, postfix, ext)
|
||||
|
||||
def remove_postfix(path):
|
||||
items = path.rsplit('.', 2)
|
||||
return items[0] + "." + items[2]
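# Example:
#   add_postfix("logs/sample.wav", 1)    # -> "logs/sample.1.wav"
#   remove_postfix("logs/sample.1.wav")  # -> "logs/sample.wav"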
|
||||
|
||||
def parallel_run(fn, items, desc="", parallel=True):
|
||||
results = []
|
||||
|
||||
if parallel:
|
||||
with closing(Pool()) as pool:
|
||||
for out in tqdm(pool.imap_unordered(
|
||||
fn, items), total=len(items), desc=desc):
|
||||
if out is not None:
|
||||
results.append(out)
|
||||
else:
|
||||
for item in tqdm(items, total=len(items), desc=desc):
|
||||
out = fn(item)
|
||||
if out is not None:
|
||||
results.append(out)
|
||||
|
||||
return results
|
||||
|
||||
def which(program):
|
||||
if os.name == "nt" and not program.endswith(".exe"):
|
||||
program += ".exe"
|
||||
|
||||
envdir_list = [os.curdir] + os.environ["PATH"].split(os.pathsep)
|
||||
|
||||
for envdir in envdir_list:
|
||||
program_path = os.path.join(envdir, program)
|
||||
if os.path.isfile(program_path) and os.access(program_path, os.X_OK):
|
||||
return program_path
|
||||
|
||||
def get_encoder_name():
|
||||
if which("avconv"):
|
||||
return "avconv"
|
||||
elif which("ffmpeg"):
|
||||
return "ffmpeg"
|
||||
else:
|
||||
return "ffmpeg"
|
||||
|
||||
def download_with_url(url, dest_path, chunk_size=32*1024):
|
||||
with open(dest_path, "wb") as f:
|
||||
response = requests.get(url, stream=True)
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
|
||||
for chunk in response.iter_content(chunk_size):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
f.write(chunk)
|
||||
return True
|
||||
|
||||
def str2bool(v):
|
||||
return v.lower() in ('true', '1')
|
||||
|
||||
def get_git_revision_hash():
|
||||
return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode("utf-8")
|
||||
|
||||
def get_git_diff():
|
||||
return subprocess.check_output(['git', 'diff']).decode("utf-8")
|
||||
|
||||
def warning(msg):
|
||||
print("="*40)
|
||||
print(" [!] {}".format(msg))
|
||||
print("="*40)
|
||||
print()
|
||||
|
||||
def query_yes_no(question, default=None):
|
||||
# Code from https://stackoverflow.com/a/3041990
|
||||
valid = {"yes": True, "y": True, "ye": True,
|
||||
"no": False, "n": False}
|
||||
if default is None:
|
||||
prompt = " [y/n] "
|
||||
elif default == "yes":
|
||||
prompt = " [Y/n] "
|
||||
elif default == "no":
|
||||
prompt = " [y/N] "
|
||||
else:
|
||||
raise ValueError("invalid default answer: '%s'" % default)
|
||||
|
||||
while True:
|
||||
sys.stdout.write(question + prompt)
|
||||
choice = input().lower()
|
||||
if default is not None and choice == '':
|
||||
return valid[default]
|
||||
elif choice in valid:
|
||||
return valid[choice]
|
||||
else:
|
||||
sys.stdout.write("Please respond with 'yes' or 'no' "
|
||||
"(or 'y' or 'n').\n")
|
50
utils/infolog.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
import atexit
|
||||
from datetime import datetime
|
||||
import json
|
||||
from threading import Thread
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
|
||||
_format = '%Y-%m-%d %H:%M:%S.%f'
|
||||
_file = None
|
||||
_run_name = None
|
||||
_slack_url = None
|
||||
|
||||
|
||||
def init(filename, run_name, slack_url=None):
|
||||
global _file, _run_name, _slack_url
|
||||
_close_logfile()
|
||||
_file = open(filename, 'a')
|
||||
_file.write('\n-----------------------------------------------------------------\n')
|
||||
_file.write('Starting new training run\n')
|
||||
_file.write('-----------------------------------------------------------------\n')
|
||||
_run_name = run_name
|
||||
_slack_url = slack_url
|
||||
|
||||
|
||||
def log(msg, slack=False):
|
||||
print(msg)
|
||||
if _file is not None:
|
||||
_file.write('[%s] %s\n' % (datetime.now().strftime(_format)[:-3], msg))
|
||||
if slack and _slack_url is not None:
|
||||
Thread(target=_send_slack, args=(msg,)).start()
|
||||
|
||||
|
||||
def _close_logfile():
|
||||
global _file
|
||||
if _file is not None:
|
||||
_file.close()
|
||||
_file = None
|
||||
|
||||
|
||||
def _send_slack(msg):
|
||||
req = Request(_slack_url)
|
||||
req.add_header('Content-Type', 'application/json')
|
||||
urlopen(req, json.dumps({
|
||||
'username': 'tacotron',
|
||||
'icon_emoji': ':taco:',
|
||||
'text': '*%s*: %s' % (_run_name, msg)
|
||||
}).encode())
|
||||
|
||||
|
||||
atexit.register(_close_logfile)
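# Typical usage (mirrors train.py): initialise the log file once, then log
# messages, optionally forwarding important ones to Slack.
#   init('logs/train.log', 'my_run', slack_url=None)
#   log('Starting new training run', slack=True)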
|
61
utils/plot.py
Normal file
|
@ -0,0 +1,61 @@
|
|||
import os
|
||||
import matplotlib
|
||||
from jamo import h2j, j2hcj
|
||||
|
||||
matplotlib.use('Agg')
|
||||
matplotlib.rc('font', family="NanumBarunGothic")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from text import PAD, EOS
|
||||
from utils import add_postfix
|
||||
from text.korean import normalize
|
||||
|
||||
def plot(alignment, info, text):
|
||||
char_len, audio_len = alignment.shape # 145, 200
|
||||
|
||||
fig, ax = plt.subplots(figsize=(char_len/5, 5))
|
||||
im = ax.imshow(
|
||||
alignment.T,
|
||||
aspect='auto',
|
||||
origin='lower',
|
||||
interpolation='none')
|
||||
|
||||
xlabel = 'Encoder timestep'
|
||||
ylabel = 'Decoder timestep'
|
||||
|
||||
if info is not None:
|
||||
xlabel += '\n{}'.format(info)
|
||||
|
||||
plt.xlabel(xlabel)
|
||||
plt.ylabel(ylabel)
|
||||
|
||||
if text:
|
||||
jamo_text = j2hcj(h2j(normalize(text)))
|
||||
pad = [PAD] * (char_len - len(jamo_text) - 1)
|
||||
|
||||
plt.xticks(range(char_len),
|
||||
[tok for tok in jamo_text] + [EOS] + pad)
|
||||
|
||||
if text is not None:
|
||||
while True:
|
||||
if text[-1] in [EOS, PAD]:
|
||||
text = text[:-1]
|
||||
else:
|
||||
break
|
||||
plt.title(text)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
def plot_alignment(
|
||||
alignment, path, info=None, text=None):
|
||||
|
||||
if text:
|
||||
tmp_alignment = alignment[:len(h2j(text)) + 2]
|
||||
|
||||
plot(tmp_alignment, info, text)
|
||||
plt.savefig(path, format='png')
|
||||
else:
|
||||
plot(alignment, info, text)
|
||||
plt.savefig(path, format='png')
|
||||
|
||||
print(" [*] Plot saved: {}".format(path))
|
60
web/static/css/main.css
Normal file
|
@ -0,0 +1,60 @@
|
|||
@media screen and (min-width: 1452px) {
|
||||
.container {
|
||||
max-width: 1152px;
|
||||
width: 1152px;
|
||||
}
|
||||
}
|
||||
@media screen and (min-width: 1260px) {
|
||||
.container {
|
||||
max-width: 960px;
|
||||
width: 960px;
|
||||
}
|
||||
}
|
||||
@media screen and (min-width: 1068px) {
|
||||
.container {
|
||||
max-width: 768px;
|
||||
width: 768px;
|
||||
}
|
||||
}
|
||||
|
||||
.container {
|
||||
margin: 0 auto;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
#wave {
|
||||
height: 100px;
|
||||
}
|
||||
|
||||
#waveform {
|
||||
display: none;
|
||||
}
|
||||
|
||||
#nav {
|
||||
position: fixed !important;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
z-index: 100;
|
||||
}
|
||||
|
||||
.card {
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.columns {
|
||||
margin-left: 0rem;
|
||||
margin-right: 0rem;
|
||||
margin-top: 0rem;
|
||||
}
|
||||
|
||||
#text {
|
||||
font-size: 1.2em;
|
||||
padding: 0.7em 1em 0.7em 1em;
|
||||
background: transparent;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.dark {
|
||||
background-color: black;
|
||||
}
|
105
web/static/js/main.js
Normal file
|
@ -0,0 +1,105 @@
|
|||
var sw;
|
||||
var wavesurfer;
|
||||
|
||||
var defaultSpeed = 0.03;
|
||||
var defaultAmplitude = 0.3;
|
||||
|
||||
var activeColors = [[32,133,252], [94,252,169], [253,71,103]];
|
||||
var inactiveColors = [[241,243,245], [206,212,218], [222,226,230], [173,181,189]];
|
||||
|
||||
function generate(ip, port, text, speaker_id) {
|
||||
$("#synthesize").addClass("is-loading");
|
||||
|
||||
var uri = 'http://' + ip + ':' + port
|
||||
var url = uri + '/generate?text=' + encodeURIComponent(text) + "&speaker_id=" + speaker_id;
|
||||
|
||||
fetch(url, {cache: 'no-cache', mode: 'cors'})
|
||||
.then(function(res) {
|
||||
if (!res.ok) throw Error(res.statusText)
|
||||
return res.blob()
|
||||
}).then(function(blob) {
|
||||
var url = URL.createObjectURL(blob);
|
||||
console.log(url);
|
||||
inProgress = false;
|
||||
wavesurfer.load(url);
|
||||
$("#synthesize").removeClass("is-loading");
|
||||
}).catch(function(err) {
|
||||
showWarning("에러가 발생했습니다");
|
||||
inProgress = false;
|
||||
$("#synthesize").removeClass("is-loading");
|
||||
});
|
||||
}
|
||||
|
||||
(function(window, document, undefined){
|
||||
window.onload = init;
|
||||
|
||||
function setDefaultColor(sw, isActive) {
|
||||
for (idx=0; idx < sw.curves.length; idx++) {
|
||||
var curve = sw.curves[idx];
|
||||
|
||||
if (isActive) {
|
||||
curve.color = activeColors[idx % activeColors.length];
|
||||
} else {
|
||||
curve.color = inactiveColors[idx % inactiveColors.length];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function init(){
|
||||
sw = new SiriWave9({
|
||||
amplitude: defaultAmplitude,
|
||||
container: document.getElementById('wave'),
|
||||
autostart: true,
|
||||
speed: defaultSpeed,
|
||||
style: 'ios9',
|
||||
});
|
||||
sw.setSpeed(defaultSpeed);
|
||||
setDefaultColor(sw, false);
|
||||
|
||||
wavesurfer = WaveSurfer.create({
|
||||
container: '#waveform',
|
||||
waveColor: 'violet',
|
||||
barWidth: 3,
|
||||
progressColor: 'purple'
|
||||
});
|
||||
|
||||
wavesurfer.on('ready', function () {
|
||||
this.width = wavesurfer.getDuration() *
|
||||
wavesurfer.params.minPxPerSec * wavesurfer.params.pixelRatio;
|
||||
this.peaks = wavesurfer.backend.getPeaks(this.width);
|
||||
|
||||
wavesurfer.play();
|
||||
});
|
||||
|
||||
wavesurfer.on('audioprocess', function () {
|
||||
var percent = wavesurfer.backend.getPlayedPercents();
|
||||
var height = this.peaks[parseInt(this.peaks.length * percent)];
|
||||
if (height > 0) {
|
||||
sw.setAmplitude(height*3);
|
||||
}
|
||||
});
|
||||
|
||||
wavesurfer.on('finish', function () {
|
||||
sw.setSpeed(defaultSpeed);
|
||||
sw.setAmplitude(defaultAmplitude);
|
||||
setDefaultColor(sw, false);
|
||||
});
|
||||
|
||||
$(document).on('click', "#synthesize", function() {
|
||||
synthesize();
|
||||
});
|
||||
|
||||
function synthesize() {
|
||||
var text = $("#text").val().trim();
|
||||
var text_length = text.length;
|
||||
|
||||
var speaker_id = $('input[name=id]:checked').val();
|
||||
var speaker = $('input[name=id]:checked').attr("speaker");
|
||||
|
||||
generate('0.0.0.0', 5000, text, speaker_id);
|
||||
|
||||
var lowpass = wavesurfer.backend.ac.createGain();
|
||||
wavesurfer.backend.setFilter(lowpass);
|
||||
}
|
||||
}
|
||||
})(window, document, undefined);
|
212
web/static/js/siriwave.js
Normal file
|
@ -0,0 +1,212 @@
|
|||
(function() {
|
||||
|
||||
////////////////////
|
||||
// SiriWave9Curve //
|
||||
////////////////////
|
||||
|
||||
function SiriWave9Curve(opt) {
|
||||
opt = opt || {};
|
||||
this.controller = opt.controller;
|
||||
this.color = opt.color;
|
||||
this.tick = 0;
|
||||
|
||||
this.respawn();
|
||||
}
|
||||
|
||||
SiriWave9Curve.prototype.respawn = function() {
|
||||
this.amplitude = 0.3 + Math.random() * 0.7;
|
||||
this.seed = Math.random();
|
||||
this.open_class = 2+(Math.random()*3)|0;
|
||||
};
|
||||
|
||||
SiriWave9Curve.prototype.equation = function(i) {
|
||||
var p = this.tick;
|
||||
var y = -1 * Math.abs(Math.sin(p)) * this.controller.amplitude * this.amplitude * this.controller.MAX * Math.pow(1/(1+Math.pow(this.open_class*i,2)),2);
|
||||
if (Math.abs(y) < 0.001) {
|
||||
this.respawn();
|
||||
}
|
||||
return y;
|
||||
};
|
||||
|
||||
SiriWave9Curve.prototype._draw = function(m) {
|
||||
this.tick += this.controller.speed * (1-0.5*Math.sin(this.seed*Math.PI));
|
||||
|
||||
var ctx = this.controller.ctx;
|
||||
ctx.beginPath();
|
||||
|
||||
var x_base = this.controller.width/2 + (-this.controller.width/4 + this.seed*(this.controller.width/2) );
|
||||
var y_base = this.controller.height/2;
|
||||
|
||||
var x, y, x_init;
|
||||
|
||||
var i = -3;
|
||||
while (i <= 3) {
|
||||
x = x_base + i * this.controller.width/4;
|
||||
y = y_base + (m * this.equation(i));
|
||||
x_init = x_init || x;
|
||||
ctx.lineTo(x, y);
|
||||
i += 0.01;
|
||||
}
|
||||
|
||||
var h = Math.abs(this.equation(0));
|
||||
var gradient = ctx.createRadialGradient(x_base, y_base, h*1.15, x_base, y_base, h * 0.3 );
|
||||
gradient.addColorStop(0, 'rgba(' + this.color.join(',') + ',0.4)');
|
||||
gradient.addColorStop(1, 'rgba(' + this.color.join(',') + ',0.2)');
|
||||
|
||||
ctx.fillStyle = gradient;
|
||||
|
||||
ctx.lineTo(x_init, y_base);
|
||||
ctx.closePath();
|
||||
|
||||
ctx.fill();
|
||||
};
|
||||
|
||||
SiriWave9Curve.prototype.draw = function() {
|
||||
this._draw(-1);
|
||||
this._draw(1);
|
||||
};
|
||||
|
||||
|
||||
//////////////
|
||||
// SiriWave //
|
||||
//////////////
|
||||
|
||||
function SiriWave9(opt) {
|
||||
opt = opt || {};
|
||||
|
||||
this.tick = 0;
|
||||
this.run = false;
|
||||
|
||||
// UI vars
|
||||
|
||||
this.ratio = opt.ratio || window.devicePixelRatio || 1;
|
||||
|
||||
this.width = this.ratio * (opt.width || 320);
|
||||
this.height = this.ratio * (opt.height || 100);
|
||||
this.MAX = this.height/2;
|
||||
|
||||
this.speed = 0.1;
|
||||
this.amplitude = opt.amplitude || 1;
|
||||
|
||||
// Interpolation
|
||||
|
||||
this.speedInterpolationSpeed = opt.speedInterpolationSpeed || 0.005;
|
||||
this.amplitudeInterpolationSpeed = opt.amplitudeInterpolationSpeed || 0.05;
|
||||
|
||||
this._interpolation = {
|
||||
speed: this.speed,
|
||||
amplitude: this.amplitude
|
||||
};
|
||||
|
||||
// Canvas
|
||||
|
||||
this.canvas = document.createElement('canvas');
|
||||
this.canvas.width = this.width;
|
||||
this.canvas.height = this.height;
|
||||
|
||||
if (opt.cover) {
|
||||
this.canvas.style.width = this.canvas.style.height = '100%';
|
||||
} else {
|
||||
this.canvas.style.width = (this.width / this.ratio) + 'px';
|
||||
this.canvas.style.height = (this.height / this.ratio) + 'px';
|
||||
}
|
||||
|
||||
this.container = opt.container || document.body;
|
||||
this.container.appendChild(this.canvas);
|
||||
|
||||
this.ctx = this.canvas.getContext('2d');
|
||||
|
||||
// Create curves
|
||||
|
||||
this.curves = [];
|
||||
for (var i = 0; i < SiriWave9.prototype.COLORS.length; i++) {
|
||||
var color = SiriWave9.prototype.COLORS[i];
|
||||
for (var j = 0; j < (3 * Math.random())|0; j++) {
|
||||
this.curves.push(new SiriWave9Curve({
|
||||
controller: this,
|
||||
color: color
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
if (opt.autostart) {
|
||||
this.start();
|
||||
}
|
||||
}
|
||||
|
||||
SiriWave9.prototype._interpolate = function(propertyStr) {
|
||||
var increment = this[ propertyStr + 'InterpolationSpeed' ];
|
||||
|
||||
if (Math.abs(this._interpolation[propertyStr] - this[propertyStr]) <= increment) {
|
||||
this[propertyStr] = this._interpolation[propertyStr];
|
||||
} else {
|
||||
if (this._interpolation[propertyStr] > this[propertyStr]) {
|
||||
this[propertyStr] += increment;
|
||||
} else {
|
||||
this[propertyStr] -= increment;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
SiriWave9.prototype._clear = function() {
|
||||
this.ctx.globalCompositeOperation = 'destination-out';
|
||||
this.ctx.fillRect(0, 0, this.width, this.height);
|
||||
this.ctx.globalCompositeOperation = 'lighter';
|
||||
};
|
||||
|
||||
SiriWave9.prototype._draw = function() {
|
||||
for (var i = 0, len = this.curves.length; i < len; i++) {
|
||||
this.curves[i].draw();
|
||||
}
|
||||
};
|

SiriWave9.prototype._startDrawCycle = function() {
  if (this.run === false) return;
  this._clear();

  // Interpolate values
  this._interpolate('amplitude');
  this._interpolate('speed');

  this._draw();
  this.phase = (this.phase + Math.PI * this.speed) % (2 * Math.PI);

  if (window.requestAnimationFrame) {
    window.requestAnimationFrame(this._startDrawCycle.bind(this));
  } else {
    setTimeout(this._startDrawCycle.bind(this), 20);
  }
};

SiriWave9.prototype.start = function() {
  this.tick = 0;
  this.run = true;
  this._startDrawCycle();
};

SiriWave9.prototype.stop = function() {
  this.tick = 0;
  this.run = false;
};

SiriWave9.prototype.setSpeed = function(v) {
  this._interpolation.speed = v;
};

SiriWave9.prototype.setNoise = SiriWave9.prototype.setAmplitude = function(v) {
  this._interpolation.amplitude = Math.max(Math.min(v, 1), 0);
};

SiriWave9.prototype.COLORS = [
  [32, 133, 252],
  [94, 252, 169],
  [253, 71, 103]
];

if (typeof define === 'function' && define.amd) {
  define(function() { return SiriWave9; });
} else {
  window.SiriWave9 = SiriWave9;
}

})();
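
For reference, a minimal usage sketch based only on the options and methods defined above; the element id and the numeric values are illustrative and are not taken from main.js, which is outside this excerpt:

var wave = new SiriWave9({
  container: document.getElementById('wave'),  // defaults to document.body if omitted
  width: 320,
  height: 100,
  amplitude: 0.3,
  autostart: true
});

wave.setAmplitude(0.9);  // clamped to [0, 1], then eased in by _interpolate
wave.setSpeed(0.2);      // eased toward 0.2 at speedInterpolationSpeed per frame
wave.stop();             // halts the draw loop until start() is called again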
74
web/templates/index.html
Normal file

@ -0,0 +1,74 @@
<!DOCTYPE html>
<html>
<head>
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width, initial-scale=1">
  <title>D.Voice</title>

  <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css">
  <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/bulma/0.5.1/css/bulma.min.css">
  <link rel="stylesheet" href="{{ url_for('static', filename='css/main.css') }}">

  <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.2.1/jquery.min.js"></script>
  <script src="https://cdnjs.cloudflare.com/ajax/libs/underscore.js/1.8.3/underscore-min.js"></script>
  <script src="https://wavesurfer-js.org/dist/wavesurfer.min.js"></script>

  <script src="{{ url_for('static', filename='js/siriwave.js') }}"></script>
  <script src="{{ url_for('static', filename='js/main.js') }}"></script>
</head>
<body class="layout-default">

  <section class="hero is-fullheight dark">
    <div class="hero-body">
      <div class="container">
        <div class="section-body" onKeyPress="return checkSubmit(event)">
          <div class="field">
            <div class="control">
              <div class="columns">
                <div class="column"></div>
                <div class="column">
                  <div id="wave"></div>
                </div>
                <div class="column"></div>
              </div>
            </div>
          </div>

          <div class="field">
            <div class="control">
              <div id="waveform"></div>
            </div>
          </div>

          <div class="field">
            <div class="control has-text-centered">
              <label class="radio">
                <input type="radio" name="id" value="0" port="5000" checked>
                Speaker 1
              </label>
            </div>
          </div>

          <div class="field">
            <div class="control has-icons-right">
              <textarea class="textarea" id="text" placeholder="{{ text }}"></textarea>
              <span class="icon is-small is-right" id="text-warning-icon" style="display:none">
                <i class="fa fa-warning"></i>
              </span>
            </div>
            <p class="help is-danger" id="text-warning" style="display:none">
              Wrong sentence
            </p>
          </div>

          <div class="field has-text-centered">
            <button class="button is-white" id="synthesize">
              Synthesize
            </button>
          </div>
        </div>
      </div>
    </div>
  </section>
</body>
</html>
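
The template only declares the empty #wave and #waveform containers; the actual event handling lives in js/main.js, which is not part of this excerpt. As a hedged sketch (not the project's own wiring) of how #waveform could be driven with the wavesurfer.js build loaded above, where the audio URL is hypothetical:

var wavesurfer = WaveSurfer.create({
  container: '#waveform',   // the empty div declared in the template
  waveColor: '#dddddd',
  progressColor: '#999999'
});

// Hypothetical URL: load a synthesized clip returned by the server and play it.
wavesurfer.load('/audio/sample.wav');
wavesurfer.on('ready', function() {
  wavesurfer.play();
});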