Fangjun Kuang
Committed by GitHub

add more models for speaker diarization (#1440)

@@ -17,8 +17,9 @@ val segmentationModel = "segmentation.onnx" @@ -17,8 +17,9 @@ val segmentationModel = "segmentation.onnx"
17 17
18 // please download it from 18 // please download it from
19 // https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx 19 // https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
  20 +// and rename it to embedding.onnx
20 // and move it to the assets folder 21 // and move it to the assets folder
21 -val embeddingModel = "3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx" 22 +val embeddingModel = "embedding.onnx"
22 23
23 // in the end, your assets folder should look like below 24 // in the end, your assets folder should look like below
24 /* 25 /*
@@ -26,7 +27,7 @@ val embeddingModel = "3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx @@ -26,7 +27,7 @@ val embeddingModel = "3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
26 /Users/fangjun/open-source/sherpa-onnx/android/SherpaOnnxSpeakerDiarization/app/src/main/assets 27 /Users/fangjun/open-source/sherpa-onnx/android/SherpaOnnxSpeakerDiarization/app/src/main/assets
27 (py38) fangjuns-MacBook-Pro:assets fangjun$ ls -lh 28 (py38) fangjuns-MacBook-Pro:assets fangjun$ ls -lh
28 total 89048 29 total 89048
29 --rw-r--r-- 1 fangjun staff 38M Oct 12 20:28 3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx 30 +-rw-r--r-- 1 fangjun staff 38M Oct 12 20:28 embedding.onnx
30 -rw-r--r-- 1 fangjun staff 5.7M Oct 12 20:28 segmentation.onnx 31 -rw-r--r-- 1 fangjun staff 5.7M Oct 12 20:28 segmentation.onnx
31 */ 32 */
32 33
@@ -63,4 +64,4 @@ object SpeakerDiarizationObject { @@ -63,4 +64,4 @@ object SpeakerDiarizationObject {
63 _sd = OfflineSpeakerDiarization(assetManager = assetManager, config = config) 64 _sd = OfflineSpeakerDiarization(assetManager = assetManager, config = config)
64 } 65 }
65 } 66 }
66 -}  
  67 +}
@@ -37,18 +37,20 @@ pushd ./android/SherpaOnnxSpeakerDiarization/app/src/main/assets/ @@ -37,18 +37,20 @@ pushd ./android/SherpaOnnxSpeakerDiarization/app/src/main/assets/
37 37
38 ls -lh 38 ls -lh
39 39
40 -model_name={{ model.model_name }}  
41 -short_name={{ model.short_name }} 40 +segmentation_model_name={{ model.segmentation.model_name }}
  41 +segmentation_short_name={{ model.segmentation.short_name }}
42 42
43 -curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/$model_name.tar.bz2  
44 -tar xvf $model_name.tar.bz2  
45 -rm $model_name.tar.bz2  
46 -mv $model_name/model.onnx segmentation.onnx  
47 -rm -rf $model_name 43 +embedding_model_name={{ model.embedding.model_name }}
  44 +embedding_short_name={{ model.embedding.short_name }}
48 45
49 -if [ ! -f 3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx ]; then  
50 - curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx  
51 -fi 46 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/$segmentation_model_name.tar.bz2
  47 +tar xvf $segmentation_model_name.tar.bz2
  48 +rm $segmentation_model_name.tar.bz2
  49 +mv $segmentation_model_name/model.onnx segmentation.onnx
  50 +rm -rf $segmentation_model_name
  51 +
  52 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/$embedding_model_name.onnx
  53 +mv $embedding_model_name.onnx embedding.onnx
52 54
53 echo "pwd: $PWD" 55 echo "pwd: $PWD"
54 ls -lh 56 ls -lh
@@ -74,12 +76,12 @@ for arch in arm64-v8a armeabi-v7a x86_64 x86; do @@ -74,12 +76,12 @@ for arch in arm64-v8a armeabi-v7a x86_64 x86; do
74 ./gradlew build 76 ./gradlew build
75 popd 77 popd
76 78
77 - mv android/SherpaOnnxSpeakerDiarization/app/build/outputs/apk/debug/app-debug.apk ./apks/sherpa-onnx-${SHERPA_ONNX_VERSION}-$arch-speaker-diarization-$short_name-3dspeaker.apk 79 + mv android/SherpaOnnxSpeakerDiarization/app/build/outputs/apk/debug/app-debug.apk ./apks/sherpa-onnx-${SHERPA_ONNX_VERSION}-$arch-speaker-diarization-$segmentation_short_name-$embedding_short_name.apk
78 ls -lh apks 80 ls -lh apks
79 rm -v ./android/SherpaOnnxSpeakerDiarization/app/src/main/jniLibs/$arch/*.so 81 rm -v ./android/SherpaOnnxSpeakerDiarization/app/src/main/jniLibs/$arch/*.so
80 done 82 done
81 83
82 -rm -rf ./android/SherpaOnnxSpeakerDiarization/app/src/main/assets/segmentation.onnx 84 +rm -rf ./android/SherpaOnnxSpeakerDiarization/app/src/main/assets/*.onnx
83 85
84 {% endfor %} 86 {% endfor %}
85 87
@@ -27,10 +27,22 @@ def get_args(): @@ -27,10 +27,22 @@ def get_args():
27 @dataclass 27 @dataclass
28 class SpeakerSegmentationModel: 28 class SpeakerSegmentationModel:
29 model_name: str 29 model_name: str
30 - short_name: str = "" 30 + short_name: str
31 31
32 32
33 -def get_models() -> List[SpeakerSegmentationModel]: 33 +@dataclass
  34 +class SpeakerEmbeddingModel:
  35 + model_name: str
  36 + short_name: str
  37 +
  38 +
  39 +@dataclass
  40 +class Model:
  41 + segmentation: SpeakerSegmentationModel
  42 + embedding: SpeakerEmbeddingModel
  43 +
  44 +
  45 +def get_segmentation_models() -> List[SpeakerSegmentationModel]:
34 models = [ 46 models = [
35 SpeakerSegmentationModel( 47 SpeakerSegmentationModel(
36 model_name="sherpa-onnx-pyannote-segmentation-3-0", 48 model_name="sherpa-onnx-pyannote-segmentation-3-0",
@@ -45,13 +57,33 @@ def get_models() -> List[SpeakerSegmentationModel]: @@ -45,13 +57,33 @@ def get_models() -> List[SpeakerSegmentationModel]:
45 return models 57 return models
46 58
47 59
  60 +def get_embedding_models() -> List[SpeakerEmbeddingModel]:
  61 + models = [
  62 + SpeakerEmbeddingModel(
  63 + model_name="3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k",
  64 + short_name="3dspeaker",
  65 + ),
  66 + SpeakerEmbeddingModel(
  67 + model_name="nemo_en_titanet_small",
  68 + short_name="nemo",
  69 + ),
  70 + ]
  71 + return models
  72 +
  73 +
48 def main(): 74 def main():
49 args = get_args() 75 args = get_args()
50 index = args.index 76 index = args.index
51 total = args.total 77 total = args.total
52 assert 0 <= index < total, (index, total) 78 assert 0 <= index < total, (index, total)
53 79
54 - all_model_list = get_models() 80 + segmentation_models = get_segmentation_models()
  81 + embedding_models = get_embedding_models()
  82 +
  83 + all_model_list = []
  84 + for s in segmentation_models:
  85 + for e in embedding_models:
  86 + all_model_list.append(Model(segmentation=s, embedding=e))
55 87
56 num_models = len(all_model_list) 88 num_models = len(all_model_list)
57 89