123 lines
4.1 KiB
Python
123 lines
4.1 KiB
Python
|
# Code based on https://github.com/carpedm20/DCGAN-tensorflow/blob/master/download.py
|
||
|
|
||
|
from __future__ import print_function
|
||
|
import os
|
||
|
import sys
|
||
|
import gzip
|
||
|
import json
|
||
|
import tarfile
|
||
|
import zipfile
|
||
|
import argparse
|
||
|
import requests
|
||
|
from tqdm import tqdm
|
||
|
from six.moves import urllib
|
||
|
|
||
|
from utils import query_yes_no
|
||
|
|
||
|
# Names accepted on the command line; the order is kept because argparse
# echoes it in its "invalid choice" error message.
_CHECKPOINT_NAMES = ['son', 'park']

parser = argparse.ArgumentParser(description='Download model checkpoints.')
parser.add_argument(
    'checkpoints',
    metavar='N',
    type=str,
    nargs='+',
    choices=_CHECKPOINT_NAMES,
    help='name of checkpoints to download [son, park]',
)
|
||
|
|
||
|
def download(url, dirpath):
    """Download *url* into *dirpath*, drawing a text progress bar on stdout.

    The local file name is the last ``/``-separated component of *url*.

    Args:
        url: URL to fetch with ``urllib.request.urlopen``.
        dirpath: existing directory to write the file into.

    Returns:
        The path of the written file (``os.path.join(dirpath, filename)``).
    """
    filename = url.split('/')[-1]
    filepath = os.path.join(dirpath, filename)
    block_sz = 8192
    status_width = 70

    u = urllib.request.urlopen(url)
    try:
        # Content-Length may be absent (e.g. chunked responses); fall back to a
        # plain byte counter instead of crashing on the missing header.
        length_header = u.headers.get("Content-Length")
        filesize = int(length_header) if length_header is not None else None
        print("Downloading: %s Bytes: %s"
              % (filename, filesize if filesize is not None else "unknown"))

        downloaded = 0
        with open(filepath, 'wb') as f:
            while True:
                buf = u.read(block_sz)
                if not buf:
                    print('')
                    break
                print('', end='\r')
                downloaded += len(buf)
                f.write(buf)
                if filesize:
                    # Clamp the bar so an understated Content-Length cannot
                    # push it past its fixed width.
                    filled = int(float(min(downloaded, filesize)) / filesize * status_width)
                    status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") %
                              ('=' * filled + '>', downloaded * 100. / filesize))
                else:
                    status = "%d bytes" % downloaded
                print(status, end='')
                sys.stdout.flush()
    finally:
        # Release the network handle even if reading or writing fails.
        u.close()
    return filepath
|
||
|
|
||
|
def download_file_from_google_drive(id, destination):
    """Download a public Google Drive file to the local path *destination*.

    Args:
        id: Google Drive file id. (The parameter name shadows the builtin but
            is kept so keyword callers keep working.)
        destination: local file path to write.

    Raises:
        requests.HTTPError: if the final response has an HTTP error status, so
            an error page is never saved in place of the requested file.
    """
    URL = "https://docs.google.com/uc?export=download"

    # The session context manager closes pooled connections on exit.
    with requests.Session() as session:
        response = session.get(URL, params={'id': id}, stream=True)
        token = get_confirm_token(response)

        if token:
            # Drive answered with a download-warning cookie; repeat the request
            # with the confirm token it carries. Release the first streamed
            # response first, since its body is never read.
            # NOTE(review): Drive's confirmation flow has changed over time and
            # may no longer use this cookie — verify against current behavior.
            response.close()
            params = {'id': id, 'confirm': token}
            response = session.get(URL, params=params, stream=True)

        response.raise_for_status()
        save_response_content(response, destination)
|
||
|
|
||
|
def get_confirm_token(response):
    """Return the value of the first cookie whose name starts with ``download_warning``.

    Args:
        response: object exposing a ``cookies`` mapping with ``items()``.

    Returns:
        The matching cookie's value, or None when no such cookie is present.
    """
    matches = (
        cookie_value
        for cookie_name, cookie_value in response.cookies.items()
        if cookie_name.startswith('download_warning')
    )
    return next(matches, None)
|
||
|
|
||
|
def save_response_content(response, destination, chunk_size=32*1024):
    """Stream the body of *response* into the file *destination*, showing progress.

    Args:
        response: a streamed HTTP response exposing ``headers`` and
            ``iter_content(chunk_size)`` (as produced by ``requests`` with
            ``stream=True`` in this file).
        destination: path of the file to (over)write.
        chunk_size: number of bytes requested per chunk.
    """
    # A missing header yields 0; pass None to tqdm in that case so it shows an
    # open-ended byte counter instead of a bogus zero-byte total.
    total_size = int(response.headers.get('content-length', 0))
    with open(destination, "wb") as f:
        with tqdm(total=total_size or None, unit='B', unit_scale=True,
                  desc=destination) as progress:
            for chunk in response.iter_content(chunk_size):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
                    # Advance by bytes, not once per chunk: both the total and
                    # the 'B' unit are measured in bytes.
                    progress.update(len(chunk))
|
||
|
|
||
|
def unzip(filepath):
    """Extract the zip archive at *filepath* into its own directory, then delete it.

    Args:
        filepath: path to a zip file; members are extracted into
            ``os.path.dirname(filepath)``.
    """
    print("Extracting: " + filepath)
    target_dir = os.path.dirname(filepath)
    archive = zipfile.ZipFile(filepath)
    try:
        archive.extractall(target_dir)
    finally:
        archive.close()
    # Only reached when extraction succeeded.
    os.remove(filepath)
|
||
|
|
||
|
def download_checkpoint(checkpoint):
    """Download (if needed) and unpack a pre-trained checkpoint into the current directory.

    The archive is saved under a fixed file name in the current working
    directory. If that file already exists, the download is skipped and the
    existing archive is unpacked.

    Args:
        checkpoint: checkpoint name, ``"son"`` or ``"park"``.

    Raises:
        ValueError: if *checkpoint* is not a known name. This is a subclass of
            ``Exception``, so existing broad handlers still catch it.
    """
    # name -> (archive file name, Google Drive file id)
    checkpoints = {
        "son": ("son-20171015.tar.gz", "0B_7wC-DuR6ORcmpaY1A5V1AzZUU"),
        "park": ("park-20171015.tar.gz", "0B_7wC-DuR6ORYjhlekl5bVlkQ2c"),
    }
    if checkpoint not in checkpoints:
        raise ValueError(" [!] Unknown checkpoint: {}".format(checkpoint))
    save_path, drive_id = checkpoints[checkpoint]

    if os.path.exists(save_path):
        print('[*] {} already exists'.format(save_path))
    else:
        try:
            download_file_from_google_drive(drive_id, save_path)
        except BaseException:
            # Also covers KeyboardInterrupt. Remove any partial archive, because
            # the exists() check above would otherwise skip the download on
            # every later run and try to unpack a truncated file.
            if os.path.exists(save_path):
                os.remove(save_path)
            raise

    # Every known checkpoint is a tarball; "r:*" detects the compression
    # transparently. Where the interpreter supports extraction filters
    # (Python 3.12+ and security backports), use the "data" filter. It rejects
    # unsafe members such as absolute paths or ".." components, which matters
    # because the archive comes from the network.
    extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
    with tarfile.open(save_path, "r:*") as tar:
        tar.extractall(**extract_kwargs)
|
||
|
|
||
|
if __name__ == '__main__':
    args = parser.parse_args()
    requested = set(args.checkpoints)

    # Usage notice (English, then Korean), followed by a blank line.
    for notice_line in (
        " [!] The pre-trained models are being made available for research purpose only",
        " [!] 학습된 모델을 연구 이외의 목적으로 사용하는 것을 금지합니다.",
    ):
        print(notice_line)
    print()

    # Nothing is downloaded unless the user explicitly consents.
    user_agreed = query_yes_no(" [?] Are you agree on this? 이에 동의하십니까?")
    if user_agreed:
        # Fixed order: 'park' before 'son', regardless of command-line order;
        # duplicates on the command line trigger a single download.
        for checkpoint_name in ('park', 'son'):
            if checkpoint_name in requested:
                download_checkpoint(checkpoint_name)
|