Transcription model override
This commit is contained in:
parent
e8bdc317f5
commit
659cf3728a
@ -10,6 +10,7 @@ parser.add_argument('--output', type=str, help='Path to the output file, will no
|
|||||||
parser.add_argument('--save-transcription', type=str, help='Path to save the transcription to')
|
parser.add_argument('--save-transcription', type=str, help='Path to save the transcription to')
|
||||||
parser.add_argument('--skip-summary', action='store_true', help='Do not summarize')
|
parser.add_argument('--skip-summary', action='store_true', help='Do not summarize')
|
||||||
parser.add_argument('--force', action='store_true', help='Overwrite existing output file without asking')
|
parser.add_argument('--force', action='store_true', help='Overwrite existing output file without asking')
|
||||||
|
parser.add_argument('--transcription-model', type=str, help='Override the default model used for transcription (Defaults to openai/whisper-large-v3)')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not args.input:
|
if not args.input:
|
||||||
@ -32,7 +33,8 @@ if args.save_transcription:
|
|||||||
|
|
||||||
audio_file = media_convert.process(args.input)
|
audio_file = media_convert.process(args.input)
|
||||||
|
|
||||||
transcription = transcribe.process(audio_file)
|
model = args.transcription_model or "openai/whisper-large-v3"
|
||||||
|
transcription = transcribe.process(audio_file, model)
|
||||||
|
|
||||||
if os.path.isfile(audio_file):
|
if os.path.isfile(audio_file):
|
||||||
os.remove(audio_file)
|
os.remove(audio_file)
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
import torch
|
import torch
|
||||||
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
|
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
|
||||||
def process(audio_path):
|
def process(audio_path, model_name):
|
||||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||||
"openai/whisper-large-v3",
|
model_name,
|
||||||
torch_dtype=torch.float32,
|
torch_dtype=torch.float32,
|
||||||
low_cpu_mem_usage=True,
|
low_cpu_mem_usage=True,
|
||||||
use_safetensors=True
|
use_safetensors=True
|
||||||
)
|
)
|
||||||
model.to("mps")
|
model.to("mps")
|
||||||
processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
|
processor = AutoProcessor.from_pretrained(model_name)
|
||||||
pipe = pipeline(
|
pipe = pipeline(
|
||||||
"automatic-speech-recognition",
|
"automatic-speech-recognition",
|
||||||
model=model,
|
model=model,
|
||||||
|
Loading…
Reference in New Issue
Block a user