From 659cf3728a68afc02f25fe1968e2038072dafb3c Mon Sep 17 00:00:00 2001 From: Elizabeth Cray <elizabeth.cray@afs.com> Date: Wed, 16 Oct 2024 09:52:53 -0400 Subject: [PATCH] Transcription model override --- meatgrinder.py | 4 +++- transcribe.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/meatgrinder.py b/meatgrinder.py index 3a3b85a..5ac8387 100644 --- a/meatgrinder.py +++ b/meatgrinder.py @@ -10,6 +10,7 @@ parser.add_argument('--output', type=str, help='Path to the output file, will no parser.add_argument('--save-transcription', type=str, help='Path to save the transcription to') parser.add_argument('--skip-summary', action='store_true', help='Do not summarize') parser.add_argument('--force', action='store_true', help='Overwrite existing output file without asking') +parser.add_argument('--transcription-model', type=str, help='Override the default model used for transcription (Defaults to openai/whisper-large-v3)') args = parser.parse_args() if not args.input: @@ -32,7 +33,8 @@ if args.save_transcription: audio_file = media_convert.process(args.input) -transcription = transcribe.process(audio_file) +model = args.transcription_model or "openai/whisper-large-v3" +transcription = transcribe.process(audio_file, model) if os.path.isfile(audio_file): os.remove(audio_file) diff --git a/transcribe.py b/transcribe.py index c75b5ee..3dfeb33 100644 --- a/transcribe.py +++ b/transcribe.py @@ -1,14 +1,14 @@ import torch from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline -def process(audio_path): +def process(audio_path, model_name): model = AutoModelForSpeechSeq2Seq.from_pretrained( - "openai/whisper-large-v3", + model_name, torch_dtype=torch.float32, low_cpu_mem_usage=True, use_safetensors=True ) model.to("mps") - processor = AutoProcessor.from_pretrained("openai/whisper-large-v3") + processor = AutoProcessor.from_pretrained(model_name) pipe = pipeline( "automatic-speech-recognition", model=model,