본문 바로가기

코딩/음소인식기

fairseq로 wav2vec2 디코딩하기

1. 파일 다운로드 및 전처리(preprocessing)

  • json 파일에서 오디오 파일 다운로드 및 다운샘플링
  • json 파일 양식에 맞게 csv 파일 생성 (total.csv)
  • fairseq 데이터 생성용 csv 파일 생성 (ch.csv)
import os
import soundfile as sf
import re
import json
import wget
from scipy.io import wavfile
import scipy.signal as sps

# NOTE(review): audio_save_dir is not referenced by the code shown here —
# downloads go to audio_16k_path in the loop below; kept in case other code
# relies on it.
audio_save_dir = '/data/jihyeon1202/nia/niach6/audio_4/'

# Load the list of utterance records. zh.json holds Chinese text, so decode it
# as UTF-8 explicitly instead of relying on the platform default encoding.
with open('/data/jihyeon1202/nia/zh.json', 'r', encoding='utf-8') as f:
    entire_dic = json.load(f)


def down_sampling(path, out_dir='/data/jihyeon1202/nia/niach6/audio_4_16k/', new_sr=16000):
    """Resample a WAV file to ``new_sr`` Hz and write it into ``out_dir``.

    The output keeps the input's base file name and its sample dtype, so
    integer PCM input is written back as integer PCM.

    Args:
        path: Path of the source WAV file.
        out_dir: Directory the resampled file is written to. Defaults to the
            16 kHz audio directory used by this script.
        new_sr: Target sampling rate in Hz.

    Returns:
        Path of the written file.
    """
    import numpy as np

    print('\ndownsample_input_path: ', path)
    wav_name = os.path.basename(path)
    print('\ndown_sampling_wav_path: ', wav_name)
    new_path = os.path.join(out_dir, wav_name)

    sr, data = wavfile.read(path)
    # len(data) / sr is the clip duration in seconds, so the sample count at
    # the new rate is duration * new_sr.
    samples = round(len(data) * float(new_sr) / sr)
    # FFT-based resampling along axis 0 (time); the result is floating point.
    new_data = sps.resample(data, samples)

    # Cast back to the source dtype. Writing the float result directly would
    # produce a 64-bit float WAV whose values lie far outside [-1, 1] for
    # integer input. Clip first so ringing cannot wrap around on the cast.
    if np.issubdtype(data.dtype, np.integer):
        info = np.iinfo(data.dtype)
        new_data = np.clip(np.round(new_data), info.min, info.max)
    new_data = new_data.astype(data.dtype)

    wavfile.write(new_path, new_sr, new_data)
    return new_path
    
    
# Build two CSV files from the JSON records:
#   total.csv - all metadata fields, the cleaned sentence, a warning flag and
#               an (empty) tagging column
#   ch.csv    - "<16 kHz wav path>,<cleaned sentence>" rows for fairseq
new_lines = []
new_lines_csv = []
audio_16k_path = '/data/jihyeon1202/nia/niach6/audio_4_16k/'

numbers = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
# Punctuation removed from sentences. Raw string so every escape reaches `re`
# unchanged (the non-raw form relied on Python keeping invalid escapes). The
# ASCII comma is in this set, which keeps the comma-joined rows parseable.
punct_re = re.compile(r'[-=+,#\?:^.@*\"※~ㆍ!』‘|\(\)\[\]`\'…》\”\“\’·。]')

for dic in entire_dic:
    print(dic)

    project_id = dic['projectId']
    document_id = dic['documentId']
    section_id = dic['sectionId']
    segment_id = dic['segmentId']
    index = dic['index']
    lang = dic['lang']
    audio_url = dic['audioUrl']
    # `null` is not a Python name (NameError); empty strings keep the CSV
    # columns present but blank.
    warning = ''
    tagging = ''
    wav_file = audio_url.split('/')[-1]
    wav_path = audio_16k_path + wav_file

    # Download into the 16 kHz directory, then resample; down_sampling writes
    # to audio_4_16k/<same file name>, i.e. it replaces the download in place.
    if not os.path.exists(wav_path):
        wget.download(audio_url, wav_path)
        down_sampling(wav_path)

    ch_sent = dic['sentence']
    ch_sent = ch_sent.lstrip().replace('\n', '')
    ch_sent = punct_re.sub('', ch_sent)

    if any(num in ch_sent for num in numbers):
        warning = 'Sentence with numbers'

    fields = [project_id, document_id, section_id, segment_id, index, lang,
              ch_sent, audio_url, warning, tagging]
    # str() in case any JSON field (e.g. index) is numeric; `+` on a
    # str and an int would raise TypeError.
    new_lines.append(','.join(str(field) for field in fields) + '\n')
    new_lines_csv.append(wav_path + ',' + ch_sent + '\n')


with open('/data/jihyeon1202/nia/niach6/total.csv', 'w', encoding='utf-8') as f:
    f.writelines(new_lines)

with open('/data/jihyeon1202/nia/niach6/ch.csv', 'w', encoding='utf-8') as f:
    f.writelines(new_lines_csv)

 

2. g2p 적용 (g2pM으로 중국어 문장을 병음으로 변환한 뒤, 각 음절을 성모/운모 단위 음소로 분리)

# Read back the "<wav path>,<sentence>" rows written in step 1. The sentences
# are Chinese text, so decode as UTF-8 explicitly.
with open('/data/jihyeon1202/nia/niach6/ch.csv', 'r', encoding='utf-8') as f:
    lines = f.readlines()
    
from g2pM import G2pM
model = G2pM()  # load the g2pM grapheme-to-phoneme model
# Vowel letters that mark where a pinyin syllable's initial (consonant part)
# ends and its final (vowel part) begins.
# NOTE(review): finals spelled with 'v'/'ü' are not covered by this list —
# confirm how g2pM spells them if such syllables must be split.
vowels = ['a', 'e', 'i', 'o', 'u']
# Same punctuation set as the preprocessing step; raw string so the escapes
# reach `re` unchanged.
punct_re = re.compile(r'[-=+,#\?:^.@*\"※~ㆍ!』‘|\(\)\[\]`\'…》\”\“\’·。]')

ph_tsv = []

for line in lines:
    # Each ch.csv row is "<wav path>,<sentence>".
    fields = line.split(',')
    wav_path = fields[0]
    ch_sent = fields[-1].lstrip().replace('\n', '')
    ch_sent = punct_re.sub('', ch_sent)

    syllables = model(ch_sent, tone=True, char_split=False)
    split_cv = []
    for syllable in map(str, syllables):
        for k, ch in enumerate(syllable):
            if ch in vowels:
                # Zero-initial syllables have an empty consonant part; skip it
                # so no empty token (double space) reaches the label and
                # dictionary files built from this output.
                if k > 0:
                    split_cv.append(syllable[:k])
                split_cv.append(syllable[k:])
                break
        # Tokens with no vowel letter (presumably punctuation or other
        # non-pinyin output) are skipped.
    phoneme = ' '.join(split_cv)

    ph_tsv.append(wav_path + '\t' + phoneme + '\n')
    
print(len(ph_tsv))

# Write one "<wav path>\t<phonemes>" line per utterance, then read the file
# back; the next step iterates over this re-read `lines`.
with open('/data/jihyeon1202/nia/niach6/nia_6.tsv', 'w', encoding='utf-8') as f:
    f.writelines(ph_tsv)

with open('/data/jihyeon1202/nia/niach6/nia_6.tsv', 'r', encoding='utf-8') as f:
    lines = f.readlines()

# Sanity check: should equal the count printed above.
print(len(lines))

 

3. test.tsv, test.phn, test.wrd, dict.phn.txt 파일 생성

# Build the fairseq audio_finetuning inputs for the "test" split from the
# "<wav path>\t<phonemes>" lines read back above:
#   test.tsv     - manifest: root directory on the first line, then
#                  "<path relative to root>\t<number of frames>" per clip
#   test.phn     - phoneme label line per clip, in manifest order
#   test.wrd     - word-level reference; identical to test.phn here
#   dict.phn.txt - one "<phoneme> 1" line per distinct phoneme, sorted
#
# NOTE(review): fairseq maps labels through dict.phn.txt, so it presumably must
# match the dictionary the checkpoint was fine-tuned with. A dictionary rebuilt
# from the test set alone can differ in size and order — confirm, or reuse the
# training dict.phn.txt.
out_dir = '/data/jihyeon1202/nia/fairseq_data/ch6/'

wav_entries = []  # (wav path, frame count) per clip
phn_list = []
wrd_list = []
phone_set = set()

for line in lines:
    # Skip lines without 'wav', i.e. lines that carry no audio path.
    if 'wav' not in line:
        continue
    parts = line.split('\t')
    fname, sent = parts[0], parts[-1]
    if not sent.endswith('\n'):
        # The file's last line may lack its newline; add one so the label
        # files keep exactly one line per clip.
        sent += '\n'

    wav_entries.append((fname, sf.info(fname).frames))
    phn_list.append(sent)
    wrd_list.append(sent)
    # split() drops the trailing newline and any empty tokens that repeated
    # spaces would otherwise turn into a bogus " 1" dictionary entry.
    phone_set.update(sent.split())

# fairseq's audio manifest reads its FIRST line as the root directory and joins
# each entry onto it. Without that line the first clip is consumed as the root,
# so the audio and label counts no longer match. Relative entries also mean
# only this root line has to change when the data is mounted elsewhere
# (e.g. inside docker).
if wav_entries:
    root_dir = os.path.commonpath(
        [os.path.dirname(os.path.abspath(p)) for p, _ in wav_entries])
else:
    root_dir = '.'
tsv_list = [root_dir + '\n']
tsv_list.extend(
    os.path.relpath(os.path.abspath(p), root_dir) + '\t' + str(n) + '\n'
    for p, n in wav_entries)

dict_list = sorted(phone + ' 1\n' for phone in phone_set)
print('DICT: ', dict_list)

os.makedirs(out_dir, exist_ok=True)
with open(os.path.join(out_dir, 'test.tsv'), 'w', encoding='utf-8') as f:
    f.writelines(tsv_list)
with open(os.path.join(out_dir, 'test.phn'), 'w', encoding='utf-8') as f:
    f.writelines(phn_list)
with open(os.path.join(out_dir, 'test.wrd'), 'w', encoding='utf-8') as f:
    f.writelines(wrd_list)
with open(os.path.join(out_dir, 'dict.phn.txt'), 'w', encoding='utf-8') as f:
    f.writelines(dict_list)

 

4. fairseq로 decoding

※ docker에서 실행하는 경우, 명령어에 들어가는 모든 경로를 docker 컨테이너 내부의 절대경로로 바꿔줘야 함

# --labels phn makes fairseq read test.phn and dict.phn.txt, the files created in step 3.
python /workspace/fairseq/examples/speech_recognition/infer.py /path/to/tsv_phn_dict/data/ --task audio_finetuning --nbest 1 --path /path/to/your/checkpoints/checkpoint_file.pt --gen-subset test --results-path /path/to/save/results/ --w2l-decoder viterbi --criterion ctc --labels phn --max-tokens 1280000

예시)

# --labels phn matches the test.phn / dict.phn.txt files written to ch6/ in step 3.
python /workspace/fairseq/examples/speech_recognition/infer.py /workspace/data/nia/fairseq_data/ch6/ --task audio_finetuning --nbest 1 --path /workspace/data/nia/fairseq_data/checkpoints/checkpoint_best.pt --gen-subset test --results-path /workspace/data/nia/fairseq_data/ch6/results/ --w2l-decoder viterbi --criterion ctc --labels phn --max-tokens 1280000

실행 결과는 위 명령어에서 지정한 --results-path 경로에 생성됨

→ hypo.units-checkpoint_best.pt-test.txt(모델 예측 음소열)와 ref.units-checkpoint_best.pt-test.txt(정답 음소열) 파일을 확인

'코딩 > 음소인식기' 카테고리의 다른 글

fairseq로 wav2vec2 finetuning하기  (0) 2022.11.21