1 Star 0 Fork 1

aurae/deepspeech.pytorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
server.py 2.58 KB
一键复制 编辑 原始数据 按行查看 历史
brotheroak 提交于 2021-04-18 05:25 +08:00 . Update server.py (#633)
import logging
import os
from tempfile import NamedTemporaryFile
import hydra
import torch
from flask import Flask, request, jsonify
from hydra.core.config_store import ConfigStore
from deepspeech_pytorch.configs.inference_config import ServerConfig
from deepspeech_pytorch.inference import run_transcribe
from deepspeech_pytorch.loader.data_loader import SpectrogramParser
from deepspeech_pytorch.utils import load_model, load_decoder
# Flask application object; the route handlers in this module attach to it.
app = Flask(__name__)

# Upload extensions accepted by the /transcribe endpoint. The handler
# lower-cases the extension before checking membership.
ALLOWED_EXTENSIONS = {'.wav', '.mp3', '.ogg', '.webm'}

# Register the structured server config with Hydra under the name "config",
# which is the config_name that the @hydra.main entry point asks for.
cs = ConfigStore.instance()
cs.store(node=ServerConfig, name="config")
@app.route('/transcribe', methods=['POST'])
def transcribe_file():
    """Transcribe an uploaded audio file and return the result as JSON.

    Expects a multipart POST carrying the audio under the ``file`` field.
    The upload is written to a temporary file that lives only for the
    duration of the request and is handed to ``run_transcribe`` together with
    the module-level ``model``, ``decoder``, ``spect_parser``, ``device`` and
    ``config`` globals, which ``main`` assigns before starting the server.

    Returns:
        A JSON response. On success it holds ``status == "OK"`` and a
        ``transcription`` entry. When the ``file`` field is missing or its
        extension is not in ``ALLOWED_EXTENSIONS`` it holds
        ``status == "error"`` and a ``message`` entry.
    """
    # The route only registers POST, so no explicit method check is needed.
    res = {}
    if 'file' not in request.files:
        res['status'] = "error"
        res['message'] = "audio file should be passed for the transcription"
        return jsonify(res)

    upload = request.files['file']
    # A missing filename is treated as "no extension" so it is rejected with
    # the regular error response instead of raising inside splitext().
    _, file_extension = os.path.splitext(upload.filename or '')
    if file_extension.lower() not in ALLOWED_EXTENSIONS:
        res['status'] = "error"
        res['message'] = "{} is not supported format.".format(file_extension)
        return jsonify(res)

    # Keep the original extension on the temporary file; it has already been
    # validated against the allow-list above.
    with NamedTemporaryFile(suffix=file_extension) as tmp_saved_audio_file:
        upload.save(tmp_saved_audio_file.name)
        logging.info('Transcribing file...')
        transcription, _ = run_transcribe(
            # Pass the on-disk path of the saved upload, not the temp-file
            # wrapper object.
            audio_path=tmp_saved_audio_file.name,
            spect_parser=spect_parser,
            model=model,
            decoder=decoder,
            device=device,
            precision=config.model.precision
        )
        logging.info('File transcribed')
        res['status'] = "OK"
        res['transcription'] = transcription
        return jsonify(res)
@hydra.main(config_name="config")
def main(cfg: ServerConfig):
    """Load the model, decoder and spectrogram parser, then start the server.

    The loaded objects are published as module-level globals because the
    ``/transcribe`` request handler reads them from module scope.

    Args:
        cfg: Server configuration supplied by Hydra. This function reads
            ``cfg.model.cuda``, ``cfg.model.model_path``, ``cfg.lm``,
            ``cfg.host`` and ``cfg.port``; the request handler additionally
            reads ``cfg.model.precision`` through the ``config`` global.
    """
    global model, spect_parser, decoder, config, device
    config = cfg
    logging.getLogger().setLevel(logging.DEBUG)

    logging.info('Setting up server...')
    device = torch.device("cuda" if cfg.model.cuda else "cpu")
    model = load_model(
        device=device,
        model_path=cfg.model.model_path
    )
    # The decoder and the spectrogram parser are both configured from
    # attributes of the loaded model.
    decoder = load_decoder(
        labels=model.labels,
        cfg=cfg.lm
    )
    spect_parser = SpectrogramParser(
        audio_conf=model.spect_cfg,
        normalize=True
    )
    logging.info('Server initialised')

    app.run(
        host=cfg.host,
        port=cfg.port,
        # Debug mode exposes Werkzeug's interactive debugger, which allows
        # code execution from the browser. Keep it off because this server
        # binds to a configurable host and accepts uploads; enable it only
        # for local development.
        debug=False,
        use_reloader=False
    )


if __name__ == "__main__":
    main()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/bigcash/deepspeech.pytorch.git
git@gitee.com:bigcash/deepspeech.pytorch.git
bigcash
deepspeech.pytorch
deepspeech.pytorch
master

搜索帮助