Transcription model override

This commit is contained in:
Elizabeth Cray 2024-10-16 09:52:53 -04:00
parent e8bdc317f5
commit 659cf3728a
2 changed files with 6 additions and 4 deletions

View File

@@ -10,6 +10,7 @@ parser.add_argument('--output', type=str, help='Path to the output file, will no
 parser.add_argument('--save-transcription', type=str, help='Path to save the transcription to')
 parser.add_argument('--skip-summary', action='store_true', help='Do not summarize')
 parser.add_argument('--force', action='store_true', help='Overwrite existing output file without asking')
+parser.add_argument('--transcription-model', type=str, help='Override the default model used for transcription (Defaults to openai/whisper-large-v3)')
 args = parser.parse_args()
 if not args.input:
@@ -32,7 +33,8 @@ if args.save_transcription:
 audio_file = media_convert.process(args.input)
-transcription = transcribe.process(audio_file)
+model = args.transcription_model or "openai/whisper-large-v3"
+transcription = transcribe.process(audio_file, model)
 if os.path.isfile(audio_file):
     os.remove(audio_file)

View File

@@ -1,14 +1,14 @@
 import torch
 from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
-def process(audio_path):
+def process(audio_path, model_name):
     model = AutoModelForSpeechSeq2Seq.from_pretrained(
-        "openai/whisper-large-v3",
+        model_name,
         torch_dtype=torch.float32,
         low_cpu_mem_usage=True,
         use_safetensors=True
     )
     model.to("mps")
-    processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
+    processor = AutoProcessor.from_pretrained(model_name)
     pipe = pipeline(
         "automatic-speech-recognition",
         model=model,