"""Batch-compute the sample rates of the audio files found under a directory."""
import os
|
||
import librosa
|
||
import logging
|
||
from pathlib import Path
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
|
||
# Logging setup: root logger at INFO with timestamped, level-tagged records.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
|
||
|
||
# File extensions (lower-case, including the leading dot) treated as audio.
SUPPORTED_EXTENSIONS = {
    '.aif',
    '.aiff',
    '.flac',
    '.m4a',
    '.mp3',
    '.ogg',
    '.wav',
}
|
||
|
||
|
||
def get_audio_sample_rate(file_path: str) -> tuple:
    """
    Return the native sample rate of a single audio file.

    The rate is queried with ``librosa.get_samplerate`` so the audio samples
    are not decoded into memory just to be thrown away.

    Args:
        file_path: Path of the audio file to inspect.

    Returns:
        tuple: ``(file_path, sample_rate, status)``. On success ``status`` is
        "成功". On any error ``sample_rate`` is ``None`` and ``status`` is
        "失败: <error message>"; the error is also logged.
    """
    try:
        # librosa.load(..., sr=None) would decode the whole file only to
        # report its rate; get_samplerate() asks for the rate directly.
        sr = librosa.get_samplerate(file_path)
    except Exception as e:
        # Deliberately broad: one unreadable file must not abort a batch run,
        # so the failure is reported through the returned status instead.
        logger.error(f"处理文件 {file_path} 时出错: {str(e)}")
        return (file_path, None, f"失败: {str(e)}")
    return (file_path, sr, "成功")
|
||
|
||
|
||
def is_audio_file(file_path: str) -> bool:
    """Return True if the path's extension (case-insensitive) is in SUPPORTED_EXTENSIONS."""
    _, extension = os.path.splitext(file_path)
    return extension.lower() in SUPPORTED_EXTENSIONS
|
||
|
||
|
||
def _collect_audio_files(input_dir: str) -> list:
    """Recursively gather the paths of all supported audio files under input_dir."""
    audio_files = []
    for root, _, names in os.walk(input_dir):
        for name in names:
            candidate = os.path.join(root, name)
            if is_audio_file(candidate):
                audio_files.append(candidate)
    return audio_files


def _write_results_csv(results: list, output_file: str) -> None:
    """Write the result records to output_file as CSV: a header row plus one row per file."""
    # Local import so this block does not depend on a file-level import.
    import csv

    # newline='' is what the csv module expects from the file object; the
    # writer then emits "\n" itself as the row terminator.
    with open(output_file, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerow(["文件路径", "采样率(Hz)", "状态"])
        for item in results:
            sr = item["sample_rate"]
            # csv.writer quotes fields containing commas, quotes or newlines,
            # which file paths and error messages may well contain.
            writer.writerow([item["file_path"], "" if sr is None else sr, item["status"]])


def batch_calculate_sample_rates(input_dir: str, output_file: str = None, max_workers: int = 4) -> list:
    """
    Compute the sample rate of every supported audio file under a directory.

    The directory is walked recursively, files are processed concurrently in
    a thread pool, and the results are sorted by file path.

    Args:
        input_dir: Directory containing the audio files.
        output_file: Path of the CSV file to write the results to; when it is
            None (or empty) no file is written.
        max_workers: Maximum number of worker threads.

    Returns:
        list: One dict per audio file with the keys "file_path",
        "sample_rate" (None on failure) and "status". An empty list is
        returned when input_dir is not an existing directory.
    """
    if not os.path.isdir(input_dir):
        logger.error(f"目录不存在: {input_dir}")
        return []

    audio_files = _collect_audio_files(input_dir)
    logger.info(f"找到 {len(audio_files)} 个音频文件,开始处理...")

    results = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Map each future back to its path so a failure can still be attributed.
        futures = {executor.submit(get_audio_sample_rate, path): path
                   for path in audio_files}

        for future in as_completed(futures):
            file_path = futures[future]
            try:
                path, sr, status = future.result()
            except Exception as e:
                logger.error(f"获取结果时出错 {file_path}: {str(e)}")
                # Keep the file in the results so it is still counted and
                # reported as a failure instead of silently disappearing.
                path, sr, status = file_path, None, f"失败: {str(e)}"

            results.append({
                "file_path": path,
                "sample_rate": sr,
                "status": status
            })
            # Only announce a sample rate when one was actually obtained.
            if sr is not None:
                logger.info(f"处理完成: {os.path.basename(path)} - 采样率: {sr} Hz")

    # Completion order is nondeterministic; sort for stable output.
    results.sort(key=lambda x: x["file_path"])

    if output_file:
        try:
            _write_results_csv(results, output_file)
        except Exception as e:
            # Best effort: a failed export must not discard the computed results.
            logger.error(f"保存结果到文件失败: {str(e)}")
        else:
            logger.info(f"结果已保存到: {output_file}")

    return results
|
||
|
||
|
||
def main():
    """
    Command-line entry point: scan a directory and report audio sample rates.

    Command-line arguments:
        input_dir: Optional directory to scan; defaults to "data/cat_sounds_4",
            the value that was previously hard-coded.
        -o / --output: Optional path of a CSV file to write the results to.
        -w / --workers: Number of worker threads (default 4).

    A success/failure summary is logged. When no output file is given, the
    first 10 results are also printed to stdout.
    """
    import argparse

    parser = argparse.ArgumentParser(description='批量计算音频文件的采样率')
    # Optional positional so existing invocations without it keep working.
    parser.add_argument('input_dir', nargs='?', default='data/cat_sounds_4',
                        help='音频文件所在目录,默认 data/cat_sounds_4')
    parser.add_argument('-o', '--output', help='结果输出CSV文件路径')
    parser.add_argument('-w', '--workers', type=int, default=4,
                        help='并行处理的线程数,默认4')

    args = parser.parse_args()

    # Run the batch job.
    results = batch_calculate_sample_rates(
        input_dir=args.input_dir,
        output_file=args.output,
        max_workers=args.workers
    )

    # Tally outcomes.
    success_count = sum(1 for item in results if item["status"] == "成功")
    fail_count = len(results) - success_count

    logger.info(f"处理完成 - 成功: {success_count}, 失败: {fail_count}, 总计: {len(results)}")

    # Without an output file, print a short summary to stdout instead.
    if not args.output and results:
        print("\n结果摘要:")
        for item in results[:10]:  # show only the first 10 results
            print(f"{os.path.basename(item['file_path'])}: {item['sample_rate']} Hz ({item['status']})")

        if len(results) > 10:
            print(f"... 还有 {len(results) - 10} 个文件未显示")
|
||
|
||
|
||
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|