Decoding wav2vec2 with fairseq
1. File download and preprocessing
- Download the audio files referenced in the JSON file and downsample them (an example JSON entry is shown below)
- Create a CSV file that follows the JSON format (total.csv)
- Create a CSV file used to build the fairseq data (ch.csv)
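For reference, the script below expects each entry of zh.json to contain at least the fields accessed in the loop. The values here are made up for illustration:

[
  {
    "projectId": "P001",
    "documentId": "D001",
    "sectionId": "S001",
    "segmentId": "SEG001",
    "index": 1,
    "lang": "zh",
    "audioUrl": "https://example.com/audio/000001.wav",
    "sentence": "中国很大。"
  }
]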
import os
import soundfile as sf
import re
import json
import wget
from scipy.io import wavfile
import scipy.signal as sps

audio_save_dir = '/data/jihyeon1202/nia/niach6/audio_4/'

with open('/data/jihyeon1202/nia/zh.json', 'r') as f:
    entire_dic = json.load(f)

def down_sampling(path):
    print('\ndownsample_input_path: ', path)
    new_sr = 16000
    wav_path = path.split('/')[-1]
    print('\ndown_sampling_wav_path: ', wav_path)
    new_path = '/data/jihyeon1202/nia/niach6/audio_4_16k/' + wav_path
    sr, data = wavfile.read(path)
    ## data: array of audio samples
    ## len(data) / sr: duration of the wave file in seconds
    ## len(data) * float(new_sr) / sr: number of samples when sampled at the new sampling rate
    samples = round(len(data) * float(new_sr) / sr)
    ## resample the audio to the new number of samples
    new_data = sps.resample(data, samples)
    wavfile.write(new_path, new_sr, new_data)
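## (Not in the original post) Optional sanity check: read a resampled file back with soundfile
## and confirm that it is 16 kHz. Pass in any file that down_sampling has already written.
def check_sample_rate(path):
    info = sf.info(path)
    ## prints the sample rate (expected: 16000) and the duration in seconds
    print(path, info.samplerate, info.frames / info.samplerate)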
new_lines = []
new_lines_csv = []
audio_16k_path = '/data/jihyeon1202/nia/niach6/audio_4_16k/'
numbers = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

for dic in entire_dic:
    print(dic)
    project_id = dic['projectId']
    document_id = dic['documentId']
    section_id = dic['sectionId']
    segment_id = dic['segmentId']
    index = dic['index']
    lang = dic['lang']
    audio_url = dic['audioUrl']
    warning = ''   ## filled in below when the sentence contains digits
    tagging = ''
    wav_file = audio_url.split('/')[-1]
    wav_path = audio_16k_path + wav_file
    if os.path.exists(wav_path): pass
    else:
        wget.download(audio_url, wav_path)
        down_sampling(wav_path)
    ch_sent = dic['sentence']
    ch_sent = ch_sent.lstrip().replace('\n', '')
    ch_sent = re.sub('[-=+,#\?:^.@*\"※~ㆍ!』‘|\(\)\[\]`\'…》\”\“\’·。]', '', ch_sent)
    for num in numbers:
        if num in ch_sent:
            warning = 'Sentence with numbers'
    new_line = project_id + ',' + document_id + ',' + section_id + ',' + segment_id + ',' + str(index) + ',' + lang + ',' + ch_sent + ',' + audio_url + ',' + warning + ',' + tagging + '\n'
    new_line_csv = wav_path + ',' + ch_sent + '\n'
    new_lines.append(new_line)
    new_lines_csv.append(new_line_csv)

with open('/data/jihyeon1202/nia/niach6/total.csv', 'w') as f:
    f.writelines(new_lines)
with open('/data/jihyeon1202/nia/niach6/ch.csv', 'w') as f:
    f.writelines(new_lines_csv)
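At this point each line of ch.csv has the form <16 kHz wav path>,<cleaned Chinese sentence>, and total.csv additionally carries the metadata columns in the order written above. A made-up example line of ch.csv:

/data/jihyeon1202/nia/niach6/audio_4_16k/000001.wav,中国很大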
2. Apply g2p
with open('/data/jihyeon1202/nia/niach6/ch.csv', 'r') as f:
    lines = f.readlines()

from g2pM import G2pM
model = G2pM()  ## load the g2p model
vowels = ['a', 'e', 'i', 'o', 'u']  ## vowel letters marking where the Chinese final begins

ph_tsv = []
for line in lines:
    wav_path = line.split(',')[0]
    ch_sent = line.split(',')[-1]
    ch_sent = ch_sent.lstrip().replace('\n', '')
    ch_sent = re.sub('[-=+,#\?:^.@*\"※~ㆍ!』‘|\(\)\[\]`\'…》\”\“\’·。]', '', ch_sent)
    phoneme = model(ch_sent, tone=True, char_split=False)
    split_cv = []
    for j in range(len(phoneme)):
        phoneme_str = str(phoneme[j])
        phoneme_list = list(phoneme_str)
        for k in range(len(phoneme_list)):
            if phoneme_list[k] in vowels:
                ## split each pinyin syllable into initial (consonant) and final (vowel + tone)
                consonant = phoneme_str[:k]
                vowel = phoneme_str[k:]
                split_cv.append(consonant)
                split_cv.append(vowel)
                break
    phoneme = ' '.join(p for p in split_cv)
    new_line = wav_path + '\t' + phoneme + '\n'
    ph_tsv.append(new_line)

print(len(ph_tsv))

with open('/data/jihyeon1202/nia/niach6/nia_6.tsv', 'w') as f:
    f.writelines(ph_tsv)
with open('/data/jihyeon1202/nia/niach6/nia_6.tsv', 'r') as f:
    lines = f.readlines()
print(len(lines))
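As a quick illustration of the initial/final split (the output values are indicative; the exact romanization comes from g2pM):

from g2pM import G2pM
model = G2pM()
print(model('中国', tone=True, char_split=False))   ## e.g. ['zhong1', 'guo2']
## splitting each syllable at its first vowel then yields 'zh ong1 g uo2',
## which is the space-separated phoneme string written to nia_6.tsv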
3. Create the test.tsv, test.phn, test.wrd, and dict.phn.txt files
tsv_list = []
phn_list = []
wrd_list = []
dict_list = []

for line in lines:
    if 'wav' in line:
        fname = line.split('\t')[0]
        #print(fname)
        sent = line.split('\t')[-1]
        #print(sent)
        length = sf.info(fname).frames
        tsv_line = fname + '\t' + str(length) + '\n'
        phn_line = sent
        wrd_line = phn_line
        phns = sent.split(' ')
        phns[-1] = phns[-1][:-1]  ## strip the trailing newline from the last phoneme
        #print(phns)
        dic_list = []
        for phn in phns:
            dict_line = phn + ' 1\n'
            if dict_line in dic_list: pass
            else: dic_list.append(dict_line)
        #print(tsv_line)
        #print(phn_line)
        #print(wrd_line)
        #print(dic_list)
        tsv_list.append(tsv_line)
        phn_list.append(phn_line)
        wrd_list.append(wrd_line)
        dict_list.extend(dic_list)
    else: pass

dict_list = list(set(dict_list))
dict_list.sort()
print('DICT: ', dict_list)

with open('/data/jihyeon1202/nia/fairseq_data/ch6/test.tsv', 'w') as f:
    f.writelines(tsv_list)
with open('/data/jihyeon1202/nia/fairseq_data/ch6/test.phn', 'w') as f:
    f.writelines(phn_list)
with open('/data/jihyeon1202/nia/fairseq_data/ch6/test.wrd', 'w') as f:
    f.writelines(wrd_list)
with open('/data/jihyeon1202/nia/fairseq_data/ch6/dict.phn.txt', 'w') as f:
    f.writelines(dict_list)
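The generated files should look roughly like this (paths and values are illustrative): each test.tsv line is <wav path><TAB><number of frames>, test.phn and test.wrd contain the space-separated phoneme strings, and each dict.phn.txt line is <phoneme> <count> (the count is simply fixed to 1 here).

test.tsv (illustrative line):
/data/jihyeon1202/nia/niach6/audio_4_16k/000001.wav	52480

test.phn / test.wrd (illustrative line):
zh ong1 g uo2

dict.phn.txt (illustrative lines):
g 1
ong1 1
uo2 1
zh 1

Note: depending on the fairseq version, the manifest reader may expect the first line of test.tsv to be a root directory, with relative wav paths on the following lines; if utterances go missing at inference time, this is the first thing to check.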
4. Decoding with fairseq
** When running inside Docker, the paths must be replaced with the absolute paths inside the Docker container
python /workspace/fairseq/examples/speech_recognition/infer.py /path/to/tsv_phn_dict/data/ --task audio_finetuning --nbest 1 --path /path/to/your/checkpoints/checkpoint_file.pt --gen-subset test --results-path /path/to/save/results/ --w2l-decoder viterbi --criterion ctc --labels ltr --max-tokens 1280000
Example)
python /workspace/fairseq/examples/speech_recognition/infer.py /workspace/data/nia/fairseq_data/ch6/ --task audio_finetuning --nbest 1 --path /workspace/data/nia/fairseq_data/checkpoints/checkpoint_best.pt --gen-subset test --results-path /workspace/data/nia/fairseq_data/ch6/results/ --w2l-decoder viterbi --criterion ctc --labels ltr --max-tokens 1280000
The decoding results are written to the --results-path specified in the command above.
→ Check hypo.units-checkpoint_best.pt-test.txt and ref.units-checkpoint_best.pt-test.txt
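The hypo.units file holds the decoded phoneme sequences and ref.units the corresponding references. As an optional extra (not part of the original post), the phone error rate can be computed from these two files. A minimal sketch, assuming the files are line-aligned and that each line ends with a "(speaker-id)" tag, using the editdistance package:

import editdistance  ## pip install editdistance

def read_units(path):
    seqs = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line.endswith(')'):
                ## drop the trailing "(speaker-id)" tag if present
                line = line[:line.rfind('(')].strip()
            seqs.append(line.split())
    return seqs

hypos = read_units('hypo.units-checkpoint_best.pt-test.txt')
refs = read_units('ref.units-checkpoint_best.pt-test.txt')

errors = sum(editdistance.eval(h, r) for h, r in zip(hypos, refs))
total = sum(len(r) for r in refs)
print('PER: {:.2f}%'.format(100.0 * errors / total))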