From 659cf3728a68afc02f25fe1968e2038072dafb3c Mon Sep 17 00:00:00 2001
From: Elizabeth Cray <elizabeth.cray@afs.com>
Date: Wed, 16 Oct 2024 09:52:53 -0400
Subject: [PATCH] Transcription model override

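Add a --transcription-model flag so the Whisper checkpoint used for
transcription can be overridden from the command line. transcribe.process()
now takes the model name as a parameter, and meatgrinder.py falls back to
openai/whisper-large-v3 when the flag is omitted.

Example invocation (the input file name is illustrative):

    python meatgrinder.py --input recording.mp3 --transcription-model openai/whisper-medium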
---
 meatgrinder.py | 4 +++-
 transcribe.py  | 6 +++---
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/meatgrinder.py b/meatgrinder.py
index 3a3b85a..5ac8387 100644
--- a/meatgrinder.py
+++ b/meatgrinder.py
@@ -10,6 +10,7 @@ parser.add_argument('--output', type=str, help='Path to the output file, will no
 parser.add_argument('--save-transcription', type=str, help='Path to save the transcription to')
 parser.add_argument('--skip-summary', action='store_true', help='Do not summarize')
 parser.add_argument('--force', action='store_true', help='Overwrite existing output file without asking')
+parser.add_argument('--transcription-model', type=str, help='Model to use for transcription (defaults to openai/whisper-large-v3)')
 args = parser.parse_args()
 
 if not args.input:
@@ -32,7 +33,8 @@ if args.save_transcription:
 
 audio_file = media_convert.process(args.input)
 
-transcription = transcribe.process(audio_file)
+model = args.transcription_model or "openai/whisper-large-v3"  # fall back to the default Whisper checkpoint
+transcription = transcribe.process(audio_file, model)
 
 if os.path.isfile(audio_file):
     os.remove(audio_file)
diff --git a/transcribe.py b/transcribe.py
index c75b5ee..3dfeb33 100644
--- a/transcribe.py
+++ b/transcribe.py
@@ -1,14 +1,14 @@
 import torch
 from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
-def process(audio_path):
+def process(audio_path, model_name):
     model = AutoModelForSpeechSeq2Seq.from_pretrained(
-        "openai/whisper-large-v3",
+        model_name,
         torch_dtype=torch.float32,
         low_cpu_mem_usage=True,
         use_safetensors=True
     )
     model.to("mps")
-    processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
+    processor = AutoProcessor.from_pretrained(model_name)
     pipe = pipeline(
         "automatic-speech-recognition",
         model=model,