Subsync first implementation (only after download/upload).

This commit is contained in:
Louis Vézina 2020-06-10 12:04:54 -04:00
parent f79faaa5c5
commit c6548c06b7
271 changed files with 56072 additions and 15 deletions

View file

@@ -150,6 +150,13 @@ defaults = {
'titlovi': {
'username': '',
'password': ''
},
'subsync': {
'use_subsync': 'False',
'use_subsync_threshold': 'False',
'subsync_threshold': '90',
'use_subsync_movie_threshold': 'False',
'subsync_movie_threshold': '70'
}
}
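These defaults are stored as strings, like every other value in config.ini. A minimal sketch (not part of this commit) of what they become once read back through the same settings helpers this change uses elsewhere:

```python
# Illustrative only: how the string defaults above translate when read back.
from config import settings

use_subsync = settings.subsync.getboolean('use_subsync')              # 'False' -> False
use_threshold = settings.subsync.getboolean('use_subsync_threshold')  # 'False' -> False
series_threshold = float(settings.subsync.subsync_threshold)          # '90' -> 90.0
movie_threshold = float(settings.subsync.subsync_movie_threshold)     # '70' -> 70.0
```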

View file

@@ -6,11 +6,6 @@ from knowit import api
from utils import get_binary
class NotMKVAndNoFFprobe(Exception):
pass
class FFprobeError(Exception):
pass
class EmbeddedSubsReader:
def __init__(self):

View file

@@ -27,6 +27,7 @@ from utils import history_log, history_log_movie, get_binary
from notifier import send_notifications, send_notifications_movie
from get_providers import get_providers, get_providers_auth, provider_throttle, provider_pool
from knowit import api
from subsyncer import subsync
from database import database, dict_mapper
from analytics import track_event
@@ -215,15 +216,20 @@ def download_subtitle(path, language, audio_language, hi, forced, providers, pro
action = "upgraded"
else:
action = "downloaded"
percent_score = round(subtitle.score * 100 / max_score, 2)
message = downloaded_language + is_forced_string + " subtitles " + action + " from " + \
downloaded_provider + " with a score of " + str(round(subtitle.score * 100 / max_score, 2))\
+ "%."
downloaded_provider + " with a score of " + str(percent_score) + "%."
sync_result = sync_subtitles(video_path=path, srt_path=downloaded_path,
srt_lang=downloaded_language_code3, media_type=media_type,
percent_score=percent_score)
if sync_result:
message += " The subtitles file have been synced."
if use_postprocessing is True:
command = pp_replace(postprocessing_cmd, path, downloaded_path, downloaded_language,
downloaded_language_code2, downloaded_language_code3, audio_language,
audio_language_code2, audio_language_code3, subtitle.language.forced)
percent_score = round(subtitle.score * 100 / max_score, 2)
if media_type == 'series':
use_pp_threshold = settings.general.getboolean('use_postprocessing_threshold')
@@ -441,14 +447,19 @@ def manual_download_subtitle(path, language, audio_language, hi, forced, subtitl
downloaded_path = saved_subtitle.storage_path
logging.debug('BAZARR Subtitles file saved to disk: ' + downloaded_path)
is_forced_string = " forced" if subtitle.language.forced else ""
message = downloaded_language + is_forced_string + " subtitles downloaded from " + downloaded_provider + " with a score of " + str(
score) + "% using manual search."
message = downloaded_language + is_forced_string + " subtitles downloaded from " + \
downloaded_provider + " with a score of " + str(score) + "% using manual search."
sync_result = sync_subtitles(video_path=path, srt_path=downloaded_path,
srt_lang=downloaded_language_code3, media_type=media_type,
percent_score=score)
if sync_result:
message += " The subtitles file have been synced."
if use_postprocessing is True:
command = pp_replace(postprocessing_cmd, path, downloaded_path, downloaded_language,
downloaded_language_code2, downloaded_language_code3, audio_language,
audio_language_code2, audio_language_code3, subtitle.language.forced)
percent_score = round(subtitle.score * 100 / max_score, 2)
if media_type == 'series':
use_pp_threshold = settings.general.getboolean('use_postprocessing_threshold')
@@ -457,7 +468,7 @@ def manual_download_subtitle(path, language, audio_language, hi, forced, subtitl
use_pp_threshold = settings.general.getboolean('use_postprocessing_threshold_movie')
pp_threshold = settings.general.postprocessing_threshold_movie
if not use_pp_threshold or (use_pp_threshold and score < float(pp_threshold)):
postprocessing(command, path)
else:
logging.debug("BAZARR post-processing skipped because subtitles score isn't below this "
@@ -559,6 +570,10 @@ def manual_upload_subtitle(path, language, forced, title, scene_name, media_type
audio_language_code2 = alpha2_from_language(audio_language)
audio_language_code3 = alpha3_from_language(audio_language)
sync_result = sync_subtitles(video_path=path, srt_path=subtitle_path, srt_lang=uploaded_language_code3,
media_type=media_type, percent_score=100)
if sync_result:
message += " The subtitles file have been synced."
if use_postprocessing is True:
command = pp_replace(postprocessing_cmd, path, subtitle_path, uploaded_language,
@@ -985,6 +1000,11 @@ def refine_from_ffprobe(path, video):
if 'codec' in data['audio'][0]:
if not video.audio_codec:
video.audio_codec = data['audio'][0]['codec']
for track in data['audio']:
if 'language' in track:
video.audio_languages.add(track['language'].alpha3)
return video
def upgrade_subtitles():
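For context, the new loop above only relies on `data['audio']` being a list of per-track dicts whose optional 'language' entries expose an `alpha3` attribute. A rough illustration of that shape with made-up track values (real output comes from knowit's ffprobe provider; the babelfish `Language` objects are an assumption beyond what the hunk itself shows):

```python
from babelfish import Language

# Illustrative track data only; tracks without a language tag are skipped.
data = {'audio': [{'codec': 'AC3', 'language': Language('eng')},
                  {'language': Language('fra')},
                  {'codec': 'AAC'}]}

audio_languages = set()
for track in data['audio']:
    if 'language' in track:
        audio_languages.add(track['language'].alpha3)
# audio_languages == {'eng', 'fra'}
```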
@@ -1197,3 +1217,21 @@ def postprocessing(command, path):
'BAZARR Post-processing result for file ' + path + ' : Nothing returned from command execution')
else:
logging.info('BAZARR Post-processing result for file ' + path + ' : ' + out)
def sync_subtitles(video_path, srt_path, srt_lang, media_type, percent_score):
if settings.subsync.getboolean('use_subsync'):
if media_type == 'series':
use_subsync_threshold = settings.subsync.getboolean('use_subsync_threshold')
subsync_threshold = settings.subsync.subsync_threshold
else:
use_subsync_threshold = settings.subsync.getboolean('use_subsync_movie_threshold')
subsync_threshold = settings.subsync.subsync_movie_threshold
if not use_subsync_threshold or (use_subsync_threshold and percent_score < float(subsync_threshold)):
subsync.sync(video_path=video_path, srt_path=srt_path, srt_lang=srt_lang)
return True
else:
logging.debug("BAZARR subsync skipped because subtitles score isn't below this "
"threshold value: " + subsync_threshold + "%")
return False
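Put plainly, syncing runs when subsync is enabled and either no threshold is configured or the subtitle's score falls below it. A hedged walk-through with made-up numbers:

```python
# Hypothetical case: a series subtitle scored 84% against the default 90% threshold.
percent_score = 84.0
use_subsync_threshold = True   # i.e. settings.subsync.getboolean('use_subsync_threshold')
subsync_threshold = '90'       # i.e. settings.subsync.subsync_threshold

if not use_subsync_threshold or percent_score < float(subsync_threshold):
    print('sync_subtitles() would call subsync.sync() and return True')
else:
    print('syncing skipped: the subtitle already scores above the threshold')
```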

View file

@@ -4,6 +4,8 @@ import os
import rarfile
import json
import hashlib
import sys
import subprocess
from config import settings, configure_captcha_func
from get_args import args
@@ -41,6 +43,27 @@ if not os.path.exists(os.path.join(args.config_dir, 'cache')):
configure_logging(settings.general.getboolean('debug') or args.debug)
import logging
# deploy requirements.txt
if not args.no_update:
try:
import lxml, numpy
except ImportError:
try:
import pip
except ImportError:
logging.info('BAZARR unable to install requirements (pip not installed).')
else:
logging.info('BAZARR installing requirements...')
subprocess.call([sys.executable, '-m', 'pip', 'install', '--user', '-r',
os.path.join(os.path.dirname(__file__), '..', 'requirements.txt')],
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
logging.info('BAZARR requirements installed.')
try:
from server import webserver
webserver.restart()
except Exception:
logging.info('BAZARR unable to restart. Please do it manually.')
# create random api_key if there's none in config.ini
if not settings.auth.apikey or settings.auth.apikey.startswith("b'"):
from binascii import hexlify

View file

@@ -82,6 +82,10 @@ def configure_logging(debug=False):
logging.getLogger("apprise").setLevel(logging.DEBUG)
logging.getLogger("engineio.server").setLevel(logging.DEBUG)
logging.getLogger("socketio.server").setLevel(logging.DEBUG)
logging.getLogger("ffsubsync.subtitle_parser").setLevel(logging.DEBUG)
logging.getLogger("ffsubsync.speech_transformers").setLevel(logging.DEBUG)
logging.getLogger("ffsubsync.ffsubsync").setLevel(logging.DEBUG)
logging.getLogger("srt").setLevel(logging.DEBUG)
logging.debug('Bazarr version: %s', os.environ["BAZARR_VERSION"])
logging.debug('Bazarr branch: %s', settings.general.branch)
logging.debug('Operating system: %s', platform.platform())
@@ -94,6 +98,10 @@ def configure_logging(debug=False):
logging.getLogger("subzero").setLevel(logging.ERROR)
logging.getLogger("engineio.server").setLevel(logging.ERROR)
logging.getLogger("socketio.server").setLevel(logging.ERROR)
logging.getLogger("ffsubsync.subtitle_parser").setLevel(logging.ERROR)
logging.getLogger("ffsubsync.speech_transformers").setLevel(logging.ERROR)
logging.getLogger("ffsubsync.ffsubsync").setLevel(logging.ERROR)
logging.getLogger("srt").setLevel(logging.ERROR)
logging.getLogger("waitress").setLevel(logging.CRITICAL)
logging.getLogger("knowit").setLevel(logging.CRITICAL)

93
bazarr/subsyncer.py Normal file
View file

@@ -0,0 +1,93 @@
import logging
import os
from ffsubsync.ffsubsync import run
from ffsubsync.constants import *
from knowit import api
from utils import get_binary
class SubSyncer:
def __init__(self):
self.reference = None
self.srtin = None
self.reference_stream = None
self.overwrite_input = True
self.ffmpeg_path = None
# unused attributes
self.encoding = DEFAULT_ENCODING
self.vlc_mode = None
self.make_test_case = None
self.gui_mode = None
self.srtout = None
self.vad = 'subs_then_auditok'
self.reference_encoding = None
self.frame_rate = DEFAULT_FRAME_RATE
self.start_seconds = DEFAULT_START_SECONDS
self.no_fix_framerate = None
self.serialize_speech = None
self.max_offset_seconds = DEFAULT_MAX_OFFSET_SECONDS
self.merge_with_reference = None
self.output_encoding = 'same'
def sync(self, video_path, srt_path, srt_lang):
self.reference = video_path
self.srtin = srt_path
self.srtout = None
ffprobe_exe = get_binary('ffprobe')
if not ffprobe_exe:
logging.debug('BAZARR FFprobe not found!')
return
else:
logging.debug('BAZARR FFprobe used is %s', ffprobe_exe)
api.initialize({'provider': 'ffmpeg', 'ffmpeg': ffprobe_exe})
data = api.know(self.reference)
if 'subtitle' in data:
for i, embedded_subs in enumerate(data['subtitle']):
if 'language' in embedded_subs:
language = embedded_subs['language'].alpha3
if language == "eng":
self.reference_stream = "s:{}".format(i)
break
if not self.reference_stream:
self.reference_stream = "s:0"
elif 'audio' in data:
audio_tracks = data['audio']
for i, audio_track in enumerate(audio_tracks):
if 'language' in audio_track:
language = audio_track['language'].alpha3
if language == srt_lang:
self.reference_stream = "a:{}".format(i)
break
if not self.reference_stream:
audio_tracks = data['audio']
for i, audio_track in enumerate(audio_tracks):
if 'language' in audio_track:
language = audio_track['language'].alpha3
if language == "eng":
self.reference_stream = "a:{}".format(i)
break
if not self.reference_stream:
self.reference_stream = "a:0"
else:
raise NoAudioTrack
ffmpeg_exe = get_binary('ffmpeg')
if not ffmpeg_exe:
logging.debug('BAZARR FFmpeg not found!')
return
else:
logging.debug('BAZARR FFmpeg used is %s', ffmpeg_exe)
self.ffmpeg_path = os.path.dirname(ffmpeg_exe)
run(self)
class NoAudioTrack(Exception):
"""Exception raised if no audio track can be found in video file."""
pass
subsync = SubSyncer()
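The module-level `subsync` instance is what `get_subtitle.py` imports. A minimal usage sketch, assuming hypothetical file paths and the alpha3 language code that `sync_subtitles()` passes in:

```python
from subsyncer import subsync

# Hypothetical paths; srt_lang mirrors the downloaded_language_code3 value
# that sync_subtitles() forwards to this call.
subsync.sync(video_path='/tv/Show/S01E01.mkv',
             srt_path='/tv/Show/S01E01.en.srt',
             srt_lang='eng')
```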

View file

@@ -17,6 +17,10 @@ import datetime
import glob
class BinaryNotFound(Exception):
pass
def history_log(action, sonarr_series_id, sonarr_episode_id, description, video_path=None, language=None, provider=None,
score=None):
database.execute("INSERT INTO table_history (action, sonarrSeriesId, sonarrEpisodeId, timestamp, description,"
@@ -42,15 +46,22 @@ def get_binary(name):
if installed_exe and os.path.isfile(installed_exe):
return installed_exe
else:
if name == 'ffprobe':
dir_name = 'ffmpeg'
else:
dir_name = name
if platform.system() == "Windows": # Windows
exe = os.path.abspath(os.path.join(binaries_dir, "Windows", "i386", name, "%s.exe" % name))
exe = os.path.abspath(os.path.join(binaries_dir, "Windows", "i386", dir_name, "%s.exe" % name))
elif platform.system() == "Darwin": # MacOSX
exe = os.path.abspath(os.path.join(binaries_dir, "MacOSX", "i386", name, name))
exe = os.path.abspath(os.path.join(binaries_dir, "MacOSX", "i386", dir_name, name))
elif platform.system() == "Linux": # Linux
exe = os.path.abspath(os.path.join(binaries_dir, "Linux", platform.machine(), name, name))
exe = os.path.abspath(os.path.join(binaries_dir, "Linux", platform.machine(), dir_name, name))
if exe and os.path.isfile(exe):
return exe
else:
raise BinaryNotFound
def cache_maintenance():
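The new `dir_name` indirection exists because the bundled ffprobe executable lives inside the ffmpeg directory. A condensed sketch of the Linux branch (the helper name and paths below are illustrative, not part of the commit):

```python
import os
import platform

def bundled_binary_path(binaries_dir, name):
    # Hypothetical helper mirroring the Linux branch above: ffprobe is shipped
    # inside the ffmpeg folder, every other binary keeps its own folder.
    dir_name = 'ffmpeg' if name == 'ffprobe' else name
    return os.path.abspath(os.path.join(binaries_dir, 'Linux', platform.machine(), dir_name, name))

# bundled_binary_path('/opt/bazarr/bin', 'ffprobe')
#   -> '/opt/bazarr/bin/Linux/x86_64/ffmpeg/ffprobe' on an x86_64 host
```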

Binary file not shown.

View file

Binary file not shown.

View file

Binary file not shown.

19
libs/auditok/__init__.py Normal file
View file

@@ -0,0 +1,19 @@
"""
:author:
Amine SEHILI <amine.sehili@gmail.com>
2015-2018
:License:
This package is published under GNU GPL Version 3.
"""
from __future__ import absolute_import
from .core import *
from .io import *
from .util import *
from . import dataset
from .exceptions import *
__version__ = "0.1.8"

794
libs/auditok/cmdline.py Normal file
View file

@@ -0,0 +1,794 @@
#!/usr/bin/env python
# encoding: utf-8
'''
auditok.auditok -- Audio Activity Detection tool
auditok.auditok is a program that can be used for Audio/Acoustic activity detection.
It can read audio data from audio files as well as from built-in device(s) or standard input
@author: Mohamed El Amine SEHILI
@copyright: 2015-2018 Mohamed El Amine SEHILI
@license: GPL v3
@contact: amine.sehili@gmail.com
@deffield updated: 01 Nov 2018
'''
import sys
import os
from optparse import OptionParser, OptionGroup
from threading import Thread
import tempfile
import wave
import time
import threading
import logging
try:
import future
from queue import Queue, Empty
except ImportError:
if sys.version_info >= (3, 0):
from queue import Queue, Empty
else:
from Queue import Queue, Empty
try:
from pydub import AudioSegment
WITH_PYDUB = True
except ImportError:
WITH_PYDUB = False
from .core import StreamTokenizer
from .io import PyAudioSource, BufferAudioSource, StdinAudioSource, player_for
from .util import ADSFactory, AudioEnergyValidator
from auditok import __version__ as version
__all__ = []
__version__ = version
__date__ = '2015-11-23'
__updated__ = '2018-10-06'
DEBUG = 0
TESTRUN = 1
PROFILE = 0
LOGGER_NAME = "AUDITOK_LOGGER"
class AudioFileFormatError(Exception):
pass
class TimeFormatError(Exception):
pass
def file_to_audio_source(filename, filetype=None, **kwargs):
lower_fname = filename.lower()
rawdata = False
if filetype is not None:
filetype = filetype.lower()
if filetype == "raw" or (filetype is None and lower_fname.endswith(".raw")):
srate = kwargs.pop("sampling_rate", None)
if srate is None:
srate = kwargs.pop("sr", None)
swidth = kwargs.pop("sample_width", None)
if swidth is None:
swidth = kwargs.pop("sw", None)
ch = kwargs.pop("channels", None)
if ch is None:
ch = kwargs.pop("ch", None)
if None in (swidth, srate, ch):
raise Exception("All audio parameters are required for raw data")
data = open(filename).read()
rawdata = True
# try first with pydub
if WITH_PYDUB:
use_channel = kwargs.pop("use_channel", None)
if use_channel is None:
use_channel = kwargs.pop("uc", None)
if use_channel is None:
use_channel = 1
else:
try:
use_channel = int(use_channel)
except ValueError:
pass
if not isinstance(use_channel, (int)) and not use_channel.lower() in ["left", "right", "mix"] :
raise ValueError("channel must be an integer or one of 'left', 'right' or 'mix'")
asegment = None
if rawdata:
asegment = AudioSegment(data, sample_width=swidth, frame_rate=srate, channels=ch)
if filetype in("wave", "wav") or (filetype is None and lower_fname.endswith(".wav")):
asegment = AudioSegment.from_wav(filename)
elif filetype == "mp3" or (filetype is None and lower_fname.endswith(".mp3")):
asegment = AudioSegment.from_mp3(filename)
elif filetype == "ogg" or (filetype is None and lower_fname.endswith(".ogg")):
asegment = AudioSegment.from_ogg(filename)
elif filetype == "flv" or (filetype is None and lower_fname.endswith(".flv")):
asegment = AudioSegment.from_flv(filename)
else:
asegment = AudioSegment.from_file(filename)
if asegment.channels > 1:
if isinstance(use_channel, int):
if use_channel > asegment.channels:
raise ValueError("Can not use channel '{0}', audio file has only {1} channels".format(use_channel, asegment.channels))
else:
asegment = asegment.split_to_mono()[use_channel - 1]
else:
ch_lower = use_channel.lower()
if ch_lower == "mix":
asegment = asegment.set_channels(1)
elif use_channel.lower() == "left":
asegment = asegment.split_to_mono()[0]
elif use_channel.lower() == "right":
asegment = asegment.split_to_mono()[1]
return BufferAudioSource(data_buffer = asegment._data,
sampling_rate = asegment.frame_rate,
sample_width = asegment.sample_width,
channels = asegment.channels)
# fall back to standard python
else:
if rawdata:
if ch != 1:
raise ValueError("Cannot handle multi-channel audio without pydub")
return BufferAudioSource(data, srate, swidth, ch)
if filetype in ("wav", "wave") or (filetype is None and lower_fname.endswith(".wav")):
wfp = wave.open(filename)
ch = wfp.getnchannels()
if ch != 1:
wfp.close()
raise ValueError("Cannot handle multi-channel audio without pydub")
srate = wfp.getframerate()
swidth = wfp.getsampwidth()
data = wfp.readframes(wfp.getnframes())
wfp.close()
return BufferAudioSource(data, srate, swidth, ch)
raise AudioFileFormatError("Cannot read audio file format")
def save_audio_data(data, filename, filetype=None, **kwargs):
lower_fname = filename.lower()
if filetype is not None:
filetype = filetype.lower()
# save raw data
if filetype == "raw" or (filetype is None and lower_fname.endswith(".raw")):
fp = open(filename, "w")
fp.write(data)
fp.close()
return
# save other types of data
# requires all audio parameters
srate = kwargs.pop("sampling_rate", None)
if srate is None:
srate = kwargs.pop("sr", None)
swidth = kwargs.pop("sample_width", None)
if swidth is None:
swidth = kwargs.pop("sw", None)
ch = kwargs.pop("channels", None)
if ch is None:
ch = kwargs.pop("ch", None)
if None in (swidth, srate, ch):
raise Exception("All audio parameters are required to save no raw data")
if filetype in ("wav", "wave") or (filetype is None and lower_fname.endswith(".wav")):
# use standard python's wave module
fp = wave.open(filename, "w")
fp.setnchannels(ch)
fp.setsampwidth(swidth)
fp.setframerate(srate)
fp.writeframes(data)
fp.close()
elif WITH_PYDUB:
asegment = AudioSegment(data, sample_width=swidth, frame_rate=srate, channels=ch)
asegment.export(filename, format=filetype)
else:
raise AudioFileFormatError("cannot write file format {0} (file name: {1})".format(filetype, filename))
def plot_all(signal, sampling_rate, energy_as_amp, detections=[], show=True, save_as=None):
import matplotlib.pyplot as plt
import numpy as np
t = np.arange(0., np.ceil(float(len(signal))) / sampling_rate, 1./sampling_rate )
if len(t) > len(signal):
t = t[: len(signal) - len(t)]
for start, end in detections:
p = plt.axvspan(start, end, facecolor='g', ec = 'r', lw = 2, alpha=0.4)
line = plt.axhline(y=energy_as_amp, lw=1, ls="--", c="r", label="Energy threshold as normalized amplitude")
plt.plot(t, signal)
legend = plt.legend(["Detection threshold"], bbox_to_anchor=(0., 1.02, 1., .102), loc=1, fontsize=16)
ax = plt.gca().add_artist(legend)
plt.xlabel("Time (s)", fontsize=24)
plt.ylabel("Amplitude (normalized)", fontsize=24)
if save_as is not None:
plt.savefig(save_as, dpi=120)
if show:
plt.show()
def seconds_to_str_fromatter(_format):
"""
Accepted format directives: %i %s %m %h
"""
# check directives are correct
if _format == "%S":
def _fromatter(seconds):
return "{:.2f}".format(seconds)
elif _format == "%I":
def _fromatter(seconds):
return "{0}".format(int(seconds * 1000))
else:
_format = _format.replace("%h", "{hrs:02d}")
_format = _format.replace("%m", "{mins:02d}")
_format = _format.replace("%s", "{secs:02d}")
_format = _format.replace("%i", "{millis:03d}")
try:
i = _format.index("%")
raise TimeFormatError("Unknow time format directive '{0}'".format(_format[i:i+2]))
except ValueError:
pass
def _fromatter(seconds):
millis = int(seconds * 1000)
hrs, millis = divmod(millis, 3600000)
mins, millis = divmod(millis, 60000)
secs, millis = divmod(millis, 1000)
return _format.format(hrs=hrs, mins=mins, secs=secs, millis=millis)
return _fromatter
class Worker(Thread):
def __init__(self, timeout=0.2, debug=False, logger=None):
self.timeout = timeout
self.debug = debug
self.logger = logger
if self.debug and self.logger is None:
self.logger = logging.getLogger(LOGGER_NAME)
self.logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler(sys.stdout)
self.logger.addHandler(handler)
self._inbox = Queue()
self._stop_request = Queue()
Thread.__init__(self)
def debug_message(self, message):
self.logger.debug(message)
def _stop_requested(self):
try:
message = self._stop_request.get_nowait()
if message == "stop":
return True
except Empty:
return False
def stop(self):
self._stop_request.put("stop")
self.join()
def send(self, message):
self._inbox.put(message)
def _get_message(self):
try:
message = self._inbox.get(timeout=self.timeout)
return message
except Empty:
return None
class TokenizerWorker(Worker):
END_OF_PROCESSING = "END_OF_PROCESSING"
def __init__(self, ads, tokenizer, analysis_window, observers):
self.ads = ads
self.tokenizer = tokenizer
self.analysis_window = analysis_window
self.observers = observers
self._inbox = Queue()
self.count = 0
Worker.__init__(self)
def run(self):
def notify_observers(data, start, end):
audio_data = b''.join(data)
self.count += 1
start_time = start * self.analysis_window
end_time = (end+1) * self.analysis_window
duration = (end - start + 1) * self.analysis_window
# notify observers
for observer in self.observers:
observer.notify({"id" : self.count,
"audio_data" : audio_data,
"start" : start,
"end" : end,
"start_time" : start_time,
"end_time" : end_time,
"duration" : duration}
)
self.ads.open()
self.tokenizer.tokenize(data_source=self, callback=notify_observers)
for observer in self.observers:
observer.notify(TokenizerWorker.END_OF_PROCESSING)
def add_observer(self, observer):
self.observers.append(observer)
def remove_observer(self, observer):
self.observers.remove(observer)
def read(self):
if self._stop_requested():
return None
else:
return self.ads.read()
class PlayerWorker(Worker):
def __init__(self, player, timeout=0.2, debug=False, logger=None):
self.player = player
Worker.__init__(self, timeout=timeout, debug=debug, logger=logger)
def run(self):
while True:
if self._stop_requested():
break
message = self._get_message()
if message is not None:
if message == TokenizerWorker.END_OF_PROCESSING:
break
audio_data = message.pop("audio_data", None)
start_time = message.pop("start_time", None)
end_time = message.pop("end_time", None)
dur = message.pop("duration", None)
_id = message.pop("id", None)
if audio_data is not None:
if self.debug:
self.debug_message("[PLAY]: Detection {id} played (start:{start}, end:{end}, dur:{dur})".format(id=_id,
start="{:5.2f}".format(start_time), end="{:5.2f}".format(end_time), dur="{:5.2f}".format(dur)))
self.player.play(audio_data)
def notify(self, message):
self.send(message)
class CommandLineWorker(Worker):
def __init__(self, command, timeout=0.2, debug=False, logger=None):
self.command = command
Worker.__init__(self, timeout=timeout, debug=debug, logger=logger)
def run(self):
while True:
if self._stop_requested():
break
message = self._get_message()
if message is not None:
if message == TokenizerWorker.END_OF_PROCESSING:
break
audio_data = message.pop("audio_data", None)
_id = message.pop("id", None)
if audio_data is not None:
raw_audio_file = tempfile.NamedTemporaryFile(delete=False)
raw_audio_file.write(audio_data)
cmd = self.command.replace("$", raw_audio_file.name)
if self.debug:
self.debug_message("[CMD ]: Detection {id} command: {cmd}".format(id=_id, cmd=cmd))
os.system(cmd)
os.unlink(raw_audio_file.name)
def notify(self, message):
self.send(message)
class TokenSaverWorker(Worker):
def __init__(self, name_format, filetype, timeout=0.2, debug=False, logger=None, **kwargs):
self.name_format = name_format
self.filetype = filetype
self.kwargs = kwargs
Worker.__init__(self, timeout=timeout, debug=debug, logger=logger)
def run(self):
while True:
if self._stop_requested():
break
message = self._get_message()
if message is not None:
if message == TokenizerWorker.END_OF_PROCESSING:
break
audio_data = message.pop("audio_data", None)
start_time = message.pop("start_time", None)
end_time = message.pop("end_time", None)
_id = message.pop("id", None)
if audio_data is not None and len(audio_data) > 0:
fname = self.name_format.format(N=_id, start = "{:.2f}".format(start_time), end = "{:.2f}".format(end_time))
try:
if self.debug:
self.debug_message("[SAVE]: Detection {id} saved as {fname}".format(id=_id, fname=fname))
save_audio_data(audio_data, fname, filetype=self.filetype, **self.kwargs)
except Exception as e:
sys.stderr.write(str(e) + "\n")
def notify(self, message):
self.send(message)
class LogWorker(Worker):
def __init__(self, print_detections=False, output_format="{start} {end}",
time_formatter=seconds_to_str_fromatter("%S"), timeout=0.2, debug=False, logger=None):
self.print_detections = print_detections
self.output_format = output_format
self.time_formatter = time_formatter
self.detections = []
Worker.__init__(self, timeout=timeout, debug=debug, logger=logger)
def run(self):
while True:
if self._stop_requested():
break
message = self._get_message()
if message is not None:
if message == TokenizerWorker.END_OF_PROCESSING:
break
audio_data = message.pop("audio_data", None)
_id = message.pop("id", None)
start = message.pop("start", None)
end = message.pop("end", None)
start_time = message.pop("start_time", None)
end_time = message.pop("end_time", None)
duration = message.pop("duration", None)
if audio_data is not None and len(audio_data) > 0:
if self.debug:
self.debug_message("[DET ]: Detection {id} (start:{start}, end:{end})".format(id=_id,
start="{:5.2f}".format(start_time),
end="{:5.2f}".format(end_time)))
if self.print_detections:
print(self.output_format.format(id = _id,
start = self.time_formatter(start_time),
end = self.time_formatter(end_time), duration = self.time_formatter(duration)))
self.detections.append((_id, start, end, start_time, end_time))
def notify(self, message):
self.send(message)
def main(argv=None):
'''Command line options.'''
program_name = os.path.basename(sys.argv[0])
program_version = version
program_build_date = "%s" % __updated__
program_version_string = '%%prog %s (%s)' % (program_version, program_build_date)
#program_usage = '''usage: spam two eggs''' # optional - will be autogenerated by optparse
program_longdesc = '''''' # optional - give further explanation about what the program does
program_license = "Copyright 2015-2018 Mohamed El Amine SEHILI \
Licensed under the General Public License (GPL) Version 3 \nhttp://www.gnu.org/licenses/"
if argv is None:
argv = sys.argv[1:]
try:
# setup option parser
parser = OptionParser(version=program_version_string, epilog=program_longdesc, description=program_license)
group = OptionGroup(parser, "[Input-Output options]")
group.add_option("-i", "--input", dest="input", help="Input audio or video file. Use - for stdin [default: read from microphone using pyaudio]", metavar="FILE")
group.add_option("-t", "--input-type", dest="input_type", help="Input audio file type. Mandatory if file name has no extension [default: %default]", type=str, default=None, metavar="String")
group.add_option("-M", "--max_time", dest="max_time", help="Max data (in seconds) to read from microphone/file [default: read until the end of file/stream]", type=float, default=None, metavar="FLOAT")
group.add_option("-O", "--output-main", dest="output_main", help="Save main stream as. If omitted main stream will not be saved [default: omitted]", type=str, default=None, metavar="FILE")
group.add_option("-o", "--output-tokens", dest="output_tokens", help="Output file name format for detections. Use {N} and {start} and {end} to build file names, example: 'Det_{N}_{start}-{end}.wav'", type=str, default=None, metavar="STRING")
group.add_option("-T", "--output-type", dest="output_type", help="Audio type used to save detections and/or main stream. If not supplied will: (1). guess from extension or (2). use wav format", type=str, default=None, metavar="STRING")
group.add_option("-u", "--use-channel", dest="use_channel", help="Choose channel to use from a multi-channel audio file (requires pydub). 'left', 'right' and 'mix' are accepted values. [Default: 1 (i.e. 1st or left channel)]", type=str, default="1", metavar="STRING")
parser.add_option_group(group)
group = OptionGroup(parser, "[Tokenization options]", "Set tokenizer options and energy threshold.")
group.add_option("-a", "--analysis-window", dest="analysis_window", help="Size of analysis window in seconds [default: %default (10ms)]", type=float, default=0.01, metavar="FLOAT")
group.add_option("-n", "--min-duration", dest="min_duration", help="Min duration of a valid audio event in seconds [default: %default]", type=float, default=0.2, metavar="FLOAT")
group.add_option("-m", "--max-duration", dest="max_duration", help="Max duration of a valid audio event in seconds [default: %default]", type=float, default=5, metavar="FLOAT")
group.add_option("-s", "--max-silence", dest="max_silence", help="Max duration of a consecutive silence within a valid audio event in seconds [default: %default]", type=float, default=0.3, metavar="FLOAT")
group.add_option("-d", "--drop-trailing-silence", dest="drop_trailing_silence", help="Drop trailing silence from a detection [default: keep trailing silence]", action="store_true", default=False)
group.add_option("-e", "--energy-threshold", dest="energy_threshold", help="Log energy threshold for detection [default: %default]", type=float, default=50, metavar="FLOAT")
parser.add_option_group(group)
group = OptionGroup(parser, "[Audio parameters]", "Define audio parameters if data is read from a headerless file (raw or stdin) or you want to use different microphone parameters.")
group.add_option("-r", "--rate", dest="sampling_rate", help="Sampling rate of audio data [default: %default]", type=int, default=16000, metavar="INT")
group.add_option("-c", "--channels", dest="channels", help="Number of channels of audio data [default: %default]", type=int, default=1, metavar="INT")
group.add_option("-w", "--width", dest="sample_width", help="Number of bytes per audio sample [default: %default]", type=int, default=2, metavar="INT")
group.add_option("-I", "--input-device-index", dest="input_device_index", help="Audio device index [default: %default] - only when using PyAudio", type=int, default=None, metavar="INT")
group.add_option("-F", "--audio-frame-per-buffer", dest="frame_per_buffer", help="Audio frame per buffer [default: %default] - only when using PyAudio", type=int, default=1024, metavar="INT")
parser.add_option_group(group)
group = OptionGroup(parser, "[Do something with detections]", "Use these options to print, play or plot detections.")
group.add_option("-C", "--command", dest="command", help="Command to call when an audio detection occurs. Use $ to represent the file name to use with the command (e.g. -C 'du -h $')", default=None, type=str, metavar="STRING")
group.add_option("-E", "--echo", dest="echo", help="Play back each detection immediately using pyaudio [default: do not play]", action="store_true", default=False)
group.add_option("-p", "--plot", dest="plot", help="Plot and show audio signal and detections (requires matplotlib)", action="store_true", default=False)
group.add_option("", "--save-image", dest="save_image", help="Save plotted audio signal and detections as a picture or a PDF file (requires matplotlib)", type=str, default=None, metavar="FILE")
group.add_option("", "--printf", dest="printf", help="print detections, one per line, using a user supplied format (e.g. '[{id}]: {start} -- {end}'). Available keywords {id}, {start}, {end} and {duration}", type=str, default="{id} {start} {end}", metavar="STRING")
group.add_option("", "--time-format", dest="time_format", help="format used to print {start} and {end}. [Default= %default]. %S: absolute time in sec. %I: absolute time in ms. If at least one of (%h, %m, %s, %i) is used, convert time into hours, minutes, seconds and millis (e.g. %h:%m:%s.%i). Only required fields are printed", type=str, default="%S", metavar="STRING")
parser.add_option_group(group)
parser.add_option("-q", "--quiet", dest="quiet", help="Do not print any information about detections [default: print 'id', 'start' and 'end' of each detection]", action="store_true", default=False)
parser.add_option("-D", "--debug", dest="debug", help="Print processing operations to STDOUT", action="store_true", default=False)
parser.add_option("", "--debug-file", dest="debug_file", help="Print processing operations to FILE", type=str, default=None, metavar="FILE")
# process options
(opts, args) = parser.parse_args(argv)
if opts.input == "-":
asource = StdinAudioSource(sampling_rate = opts.sampling_rate,
sample_width = opts.sample_width,
channels = opts.channels)
#read data from a file
elif opts.input is not None:
asource = file_to_audio_source(filename=opts.input, filetype=opts.input_type, uc=opts.use_channel)
# read data from microphone via pyaudio
else:
try:
asource = PyAudioSource(sampling_rate = opts.sampling_rate,
sample_width = opts.sample_width,
channels = opts.channels,
frames_per_buffer = opts.frame_per_buffer,
input_device_index = opts.input_device_index)
except Exception:
sys.stderr.write("Cannot read data from audio device!\n")
sys.stderr.write("You should either install pyaudio or read data from STDIN\n")
sys.exit(2)
logger = logging.getLogger(LOGGER_NAME)
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler(sys.stdout)
if opts.quiet or not opts.debug:
# only critical messages will be printed
handler.setLevel(logging.CRITICAL)
else:
handler.setLevel(logging.DEBUG)
logger.addHandler(handler)
if opts.debug_file is not None:
logger.setLevel(logging.DEBUG)
opts.debug = True
handler = logging.FileHandler(opts.debug_file, "w")
fmt = logging.Formatter('[%(asctime)s] | %(message)s')
handler.setFormatter(fmt)
handler.setLevel(logging.DEBUG)
logger.addHandler(handler)
record = opts.output_main is not None or opts.plot or opts.save_image is not None
ads = ADSFactory.ads(audio_source = asource, block_dur = opts.analysis_window, max_time = opts.max_time, record = record)
validator = AudioEnergyValidator(sample_width=asource.get_sample_width(), energy_threshold=opts.energy_threshold)
if opts.drop_trailing_silence:
mode = StreamTokenizer.DROP_TRAILING_SILENCE
else:
mode = 0
analysis_window_per_second = 1. / opts.analysis_window
tokenizer = StreamTokenizer(validator=validator, min_length=opts.min_duration * analysis_window_per_second,
max_length=int(opts.max_duration * analysis_window_per_second),
max_continuous_silence=opts.max_silence * analysis_window_per_second,
mode = mode)
observers = []
tokenizer_worker = None
if opts.output_tokens is not None:
try:
# check user format is correct
fname = opts.output_tokens.format(N=0, start=0, end=0)
# find file type for detections
tok_type = opts.output_type
if tok_type is None:
tok_type = os.path.splitext(opts.output_tokens)[1][1:]
if tok_type == "":
tok_type = "wav"
token_saver = TokenSaverWorker(name_format=opts.output_tokens, filetype=tok_type,
debug=opts.debug, logger=logger, sr=asource.get_sampling_rate(),
sw=asource.get_sample_width(),
ch=asource.get_channels())
observers.append(token_saver)
except Exception:
sys.stderr.write("Wrong format for detections file name: '{0}'\n".format(opts.output_tokens))
sys.exit(2)
if opts.echo:
try:
player = player_for(asource)
player_worker = PlayerWorker(player=player, debug=opts.debug, logger=logger)
observers.append(player_worker)
except Exception:
sys.stderr.write("Cannot get an audio player!\n")
sys.stderr.write("You should either install pyaudio or supply a command (-C option) to play audio\n")
sys.exit(2)
if opts.command is not None and len(opts.command) > 0:
cmd_worker = CommandLineWorker(command=opts.command, debug=opts.debug, logger=logger)
observers.append(cmd_worker)
if not opts.quiet or opts.plot is not None or opts.save_image is not None:
oformat = opts.printf.replace("\\n", "\n").replace("\\t", "\t").replace("\\r", "\r")
converter = seconds_to_str_fromatter(opts.time_format)
log_worker = LogWorker(print_detections = not opts.quiet, output_format=oformat,
time_formatter=converter, logger=logger, debug=opts.debug)
observers.append(log_worker)
tokenizer_worker = TokenizerWorker(ads, tokenizer, opts.analysis_window, observers)
def _save_main_stream():
# find file type
main_type = opts.output_type
if main_type is None:
main_type = os.path.splitext(opts.output_main)[1][1:]
if main_type == "":
main_type = "wav"
ads.close()
ads.rewind()
data = ads.get_audio_source().get_data_buffer()
if len(data) > 0:
save_audio_data(data=data, filename=opts.output_main, filetype=main_type, sr=asource.get_sampling_rate(),
sw = asource.get_sample_width(),
ch = asource.get_channels())
def _plot():
import numpy as np
ads.close()
ads.rewind()
data = ads.get_audio_source().get_data_buffer()
signal = AudioEnergyValidator._convert(data, asource.get_sample_width())
detections = [(det[3] , det[4]) for det in log_worker.detections]
max_amplitude = 2**(asource.get_sample_width() * 8 - 1) - 1
energy_as_amp = np.sqrt(np.exp(opts.energy_threshold * np.log(10) / 10)) / max_amplitude
plot_all(signal / max_amplitude, asource.get_sampling_rate(), energy_as_amp, detections, show = opts.plot, save_as = opts.save_image)
# start observer threads
for obs in observers:
obs.start()
# start tokenization thread
tokenizer_worker.start()
while True:
time.sleep(1)
if len(threading.enumerate()) == 1:
break
tokenizer_worker = None
if opts.output_main is not None:
_save_main_stream()
if opts.plot or opts.save_image is not None:
_plot()
return 0
except KeyboardInterrupt:
if tokenizer_worker is not None:
tokenizer_worker.stop()
for obs in observers:
obs.stop()
if opts.output_main is not None:
_save_main_stream()
if opts.plot or opts.save_image is not None:
_plot()
return 0
except Exception as e:
sys.stderr.write(program_name + ": " + str(e) + "\n")
sys.stderr.write("for help use -h\n")
return 2
if __name__ == "__main__":
if DEBUG:
sys.argv.append("-h")
if TESTRUN:
import doctest
doctest.testmod()
if PROFILE:
import cProfile
import pstats
profile_filename = 'auditok.auditok_profile.txt'
cProfile.run('main()', profile_filename)
statsfile = open("profile_stats.txt", "wb")
p = pstats.Stats(profile_filename, stream=statsfile)
stats = p.strip_dirs().sort_stats('cumulative')
stats.print_stats()
statsfile.close()
sys.exit(0)
sys.exit(main())

437
libs/auditok/core.py Normal file
View file

@@ -0,0 +1,437 @@
"""
This module gathers processing (i.e. tokenization) classes.
Class summary
=============
.. autosummary::
StreamTokenizer
"""
from auditok.util import DataValidator
__all__ = ["StreamTokenizer"]
class StreamTokenizer():
"""
Class for stream tokenizers. It implements a 4-state automaton scheme
to extract sub-sequences of interest on the fly.
:Parameters:
`validator` :
instance of `DataValidator` that implements `is_valid` method.
`min_length` : *(int)*
Minimum number of frames of a valid token. This includes all \
tolerated non valid frames within the token.
`max_length` : *(int)*
Maximum number of frames of a valid token. This includes all \
tolerated non valid frames within the token.
`max_continuous_silence` : *(int)*
Maximum number of consecutive non-valid frames within a token.
Note that, within a valid token, there may be many tolerated \
*silent* regions that contain each a number of non valid frames up to \
`max_continuous_silence`
`init_min` : *(int, default=0)*
Minimum number of consecutive valid frames that must be **initially** \
gathered before any sequence of non valid frames can be tolerated. This
option is not always needed, it can be used to drop non-valid tokens as
early as possible. **Default = 0** means that the option is by default
ineffective.
`init_max_silence` : *(int, default=0)*
Maximum number of tolerated consecutive non-valid frames if the \
number of already gathered valid frames has not yet reached 'init_min'.
This argument is normally used if `init_min` is used. **Default = 0**,
by default this argument is not taken into consideration.
`mode` : *(int, default=0)*
`mode` can be:
1. `StreamTokenizer.STRICT_MIN_LENGTH`:
if token *i* is delivered because `max_length`
is reached, and token *i+1* is immediately adjacent to
token *i* (i.e. token *i* ends at frame *k* and token *i+1* starts
at frame *k+1*) then accept token *i+1* only if it has a size of at
least `min_length`. The default behavior is to accept token *i+1*
even if it is shorter than `min_length` (given that the above conditions
are fulfilled of course).
:Examples:
In the following code, without `STRICT_MIN_LENGTH`, the 'BB' token is
accepted although it is shorter than `min_length` (3), because it immediately
follows the latest delivered token:
.. code:: python
from auditok import StreamTokenizer, StringDataSource, DataValidator
class UpperCaseChecker(DataValidator):
def is_valid(self, frame):
return frame.isupper()
dsource = StringDataSource("aaaAAAABBbbb")
tokenizer = StreamTokenizer(validator=UpperCaseChecker(),
min_length=3,
max_length=4,
max_continuous_silence=0)
tokenizer.tokenize(dsource)
:output:
.. code:: python
[(['A', 'A', 'A', 'A'], 3, 6), (['B', 'B'], 7, 8)]
The following tokenizer will however reject the 'BB' token:
.. code:: python
dsource = StringDataSource("aaaAAAABBbbb")
tokenizer = StreamTokenizer(validator=UpperCaseChecker(),
min_length=3, max_length=4,
max_continuous_silence=0,
mode=StreamTokenizer.STRICT_MIN_LENGTH)
tokenizer.tokenize(dsource)
:output:
.. code:: python
[(['A', 'A', 'A', 'A'], 3, 6)]
2. `StreamTokenizer.DROP_TRAILING_SILENCE`: drop all tailing non-valid frames
from a token to be delivered if and only if it is not **truncated**.
This can be a bit tricky. A token is actually delivered if:
- a. `max_continuous_silence` is reached
:or:
- b. Its length reaches `max_length`. This is called a **truncated** token
In the current implementation, a `StreamTokenizer`'s decision is only based on already seen
data and on incoming data. Thus, if a token is truncated at a non-valid but tolerated
frame (`max_length` is reached but `max_continuous_silence` not yet) any tailing
silence will be kept because it can potentially be part of valid token (if `max_length`
was bigger). But if `max_continuous_silence` is reached before `max_length`, the delivered
token will not be considered as truncated but a result of *normal* end of detection
(i.e. no more valid data). In that case the tailing silence can be removed if you use
the `StreamTokenizer.DROP_TRAILING_SILENCE` mode.
:Example:
.. code:: python
tokenizer = StreamTokenizer(validator=UpperCaseChecker(), min_length=3,
max_length=6, max_continuous_silence=3,
mode=StreamTokenizer.DROP_TRAILING_SILENCE)
dsource = StringDataSource("aaaAAAaaaBBbbbb")
tokenizer.tokenize(dsource)
:output:
.. code:: python
[(['A', 'A', 'A', 'a', 'a', 'a'], 3, 8), (['B', 'B'], 9, 10)]
The first token is delivered with its tailing silence because it is truncated
while the second one has its tailing frames removed.
Without `StreamTokenizer.DROP_TRAILING_SILENCE` the output would be:
.. code:: python
[(['A', 'A', 'A', 'a', 'a', 'a'], 3, 8), (['B', 'B', 'b', 'b', 'b'], 9, 13)]
3. `StreamTokenizer.STRICT_MIN_LENGTH | StreamTokenizer.DROP_TRAILING_SILENCE`:
use both options. That means: first remove tailing silence, then check if the
token still has at least a length of `min_length`.
"""
SILENCE = 0
POSSIBLE_SILENCE = 1
POSSIBLE_NOISE = 2
NOISE = 3
STRICT_MIN_LENGTH = 2
DROP_TRAILING_SILENCE = 4
# alias
DROP_TAILING_SILENCE = 4
def __init__(self, validator,
min_length, max_length, max_continuous_silence,
init_min=0, init_max_silence=0,
mode=0):
if not isinstance(validator, DataValidator):
raise TypeError("'validator' must be an instance of 'DataValidator'")
if max_length <= 0:
raise ValueError("'max_length' must be > 0 (value={0})".format(max_length))
if min_length <= 0 or min_length > max_length:
raise ValueError("'min_length' must be > 0 and <= 'max_length' (value={0})".format(min_length))
if max_continuous_silence >= max_length:
raise ValueError("'max_continuous_silence' must be < 'max_length' (value={0})".format(max_continuous_silence))
if init_min >= max_length:
raise ValueError("'init_min' must be < 'max_length' (value={0})".format(max_continuous_silence))
self.validator = validator
self.min_length = min_length
self.max_length = max_length
self.max_continuous_silence = max_continuous_silence
self.init_min = init_min
self.init_max_silent = init_max_silence
self._mode = None
self.set_mode(mode)
self._strict_min_length = (mode & self.STRICT_MIN_LENGTH) != 0
self._drop_tailing_silence = (mode & self.DROP_TRAILING_SILENCE) != 0
self._deliver = None
self._tokens = None
self._state = None
self._data = None
self._contiguous_token = False
self._init_count = 0
self._silence_length = 0
self._start_frame = 0
self._current_frame = 0
def set_mode(self, mode):
"""
:Parameters:
`mode` : *(int)*
New mode, must be one of:
- `StreamTokenizer.STRICT_MIN_LENGTH`
- `StreamTokenizer.DROP_TRAILING_SILENCE`
- `StreamTokenizer.STRICT_MIN_LENGTH | StreamTokenizer.DROP_TRAILING_SILENCE`
- `0`
See `StreamTokenizer.__init__` for more information about the mode.
"""
if not mode in [self.STRICT_MIN_LENGTH, self.DROP_TRAILING_SILENCE,
self.STRICT_MIN_LENGTH | self.DROP_TRAILING_SILENCE, 0]:
raise ValueError("Wrong value for mode")
self._mode = mode
self._strict_min_length = (mode & self.STRICT_MIN_LENGTH) != 0
self._drop_tailing_silence = (mode & self.DROP_TRAILING_SILENCE) != 0
def get_mode(self):
"""
Return the current mode. To check whether a specific mode is activated use
the bitwise 'and' operator `&`. Example:
.. code:: python
if mode & self.STRICT_MIN_LENGTH != 0:
do_something()
"""
return self._mode
def _reinitialize(self):
self._contiguous_token = False
self._data = []
self._tokens = []
self._state = self.SILENCE
self._current_frame = -1
self._deliver = self._append_token
def tokenize(self, data_source, callback=None):
"""
Read data from `data_source`, one frame a time, and process the read frames in
order to detect sequences of frames that make up valid tokens.
:Parameters:
`data_source` : instance of the :class:`DataSource` class that implements a `read` method.
'read' should return a slice of signal, i.e. frame (of whatever \
type as long as it can be processed by validator) and None if \
there is no more signal.
`callback` : an optional 3-argument function.
If a `callback` function is given, it will be called each time a valid token
is found.
:Returns:
A list of tokens if `callback` is None. Each token is tuple with the following elements:
.. code python
(data, start, end)
where `data` is a list of read frames, `start`: index of the first frame in the
original data and `end` : index of the last frame.
"""
self._reinitialize()
if callback is not None:
self._deliver = callback
while True:
frame = data_source.read()
if frame is None:
break
self._current_frame += 1
self._process(frame)
self._post_process()
if callback is None:
_ret = self._tokens
self._tokens = None
return _ret
def _process(self, frame):
frame_is_valid = self.validator.is_valid(frame)
if self._state == self.SILENCE:
if frame_is_valid:
# seems we got a valid frame after a silence
self._init_count = 1
self._silence_length = 0
self._start_frame = self._current_frame
self._data.append(frame)
if self._init_count >= self.init_min:
self._state = self.NOISE
if len(self._data) >= self.max_length:
self._process_end_of_detection(True)
else:
self._state = self.POSSIBLE_NOISE
elif self._state == self.POSSIBLE_NOISE:
if frame_is_valid:
self._silence_length = 0
self._init_count += 1
self._data.append(frame)
if self._init_count >= self.init_min:
self._state = self.NOISE
if len(self._data) >= self.max_length:
self._process_end_of_detection(True)
else:
self._silence_length += 1
if self._silence_length > self.init_max_silent or \
len(self._data) + 1 >= self.max_length:
# either init_max_silent or max_length is reached
# before _init_count, back to silence
self._data = []
self._state = self.SILENCE
else:
self._data.append(frame)
elif self._state == self.NOISE:
if frame_is_valid:
self._data.append(frame)
if len(self._data) >= self.max_length:
self._process_end_of_detection(True)
elif self.max_continuous_silence <= 0:
# max token reached at this frame will _deliver if _contiguous_token
# and not _strict_min_length
self._process_end_of_detection()
self._state = self.SILENCE
else:
# this is the first silent frame following a valid one
# and it is tolerated
self._silence_length = 1
self._data.append(frame)
self._state = self.POSSIBLE_SILENCE
if len(self._data) == self.max_length:
self._process_end_of_detection(True)
# don't reset _silence_length because we still
# need to know the total number of silent frames
elif self._state == self.POSSIBLE_SILENCE:
if frame_is_valid:
self._data.append(frame)
self._silence_length = 0
self._state = self.NOISE
if len(self._data) >= self.max_length:
self._process_end_of_detection(True)
else:
if self._silence_length >= self.max_continuous_silence:
if self._silence_length < len(self._data):
# _deliver only if gathered frames aren't all silent
self._process_end_of_detection()
else:
self._data = []
self._state = self.SILENCE
self._silence_length = 0
else:
self._data.append(frame)
self._silence_length += 1
if len(self._data) >= self.max_length:
self._process_end_of_detection(True)
# don't reset _silence_length because we still
# need to know the total number of silent frames
def _post_process(self):
if self._state == self.NOISE or self._state == self.POSSIBLE_SILENCE:
if len(self._data) > 0 and len(self._data) > self._silence_length:
self._process_end_of_detection()
def _process_end_of_detection(self, truncated=False):
if not truncated and self._drop_tailing_silence and self._silence_length > 0:
# happens if max_continuous_silence is reached
# or max_length is reached at a silent frame
self._data = self._data[0: - self._silence_length]
if (len(self._data) >= self.min_length) or \
(len(self._data) > 0 and
not self._strict_min_length and self._contiguous_token):
_end_frame = self._start_frame + len(self._data) - 1
self._deliver(self._data, self._start_frame, _end_frame)
if truncated:
# next token (if any) will start at _current_frame + 1
self._start_frame = self._current_frame + 1
# remember that it is contiguous with the just delivered one
self._contiguous_token = True
else:
self._contiguous_token = False
else:
self._contiguous_token = False
self._data = []
def _append_token(self, data, start, end):
self._tokens.append((data, start, end))

Binary file not shown.

18
libs/auditok/dataset.py Normal file
View file

@@ -0,0 +1,18 @@
"""
This module contains links to audio files you can use for test purposes.
"""
import os
__all__ = ["one_to_six_arabic_16000_mono_bc_noise", "was_der_mensch_saet_mono_44100_lead_trail_silence"]
_current_dir = os.path.dirname(os.path.realpath(__file__))
one_to_six_arabic_16000_mono_bc_noise = "{cd}{sep}data{sep}1to6arabic_\
16000_mono_bc_noise.wav".format(cd=_current_dir, sep=os.path.sep)
"""A wave file that contains a pronunciation of Arabic numbers from 1 to 6"""
was_der_mensch_saet_mono_44100_lead_trail_silence = "{cd}{sep}data{sep}was_\
der_mensch_saet_das_wird_er_vielfach_ernten_44100Hz_mono_lead_trail_\
silence.wav".format(cd=_current_dir, sep=os.path.sep)
""" A wave file that contains a sentence between long leading and trailing periods of silence"""

View file

@@ -0,0 +1,3 @@
class DuplicateArgument(Exception):
pass

517
libs/auditok/io.py Normal file
View file

@@ -0,0 +1,517 @@
"""
Module for low-level audio input-output operations.
Class summary
=============
.. autosummary::
AudioSource
Rewindable
BufferAudioSource
WaveAudioSource
PyAudioSource
StdinAudioSource
PyAudioPlayer
Function summary
================
.. autosummary::
from_file
player_for
"""
from abc import ABCMeta, abstractmethod
import wave
import sys
__all__ = ["AudioSource", "Rewindable", "BufferAudioSource", "WaveAudioSource",
"PyAudioSource", "StdinAudioSource", "PyAudioPlayer", "from_file", "player_for"]
DEFAULT_SAMPLE_RATE = 16000
DEFAULT_SAMPLE_WIDTH = 2
DEFAULT_NB_CHANNELS = 1
class AudioSource():
"""
Base class for audio source objects.
Subclasses should implement methods to open/close an audio stream
and read the desired amount of audio samples.
:Parameters:
`sampling_rate` : int
Number of samples per second of audio stream. Default = 16000.
`sample_width` : int
Size in bytes of one audio sample. Possible values : 1, 2, 4.
Default = 2.
`channels` : int
Number of channels of audio stream. The current version supports
only mono audio streams (i.e. one channel).
"""
__metaclass__ = ABCMeta
def __init__(self, sampling_rate=DEFAULT_SAMPLE_RATE,
sample_width=DEFAULT_SAMPLE_WIDTH,
channels=DEFAULT_NB_CHANNELS):
if not sample_width in (1, 2, 4):
raise ValueError("Sample width must be one of: 1, 2 or 4 (bytes)")
if channels != 1:
raise ValueError("Only mono audio is currently handled")
self._sampling_rate = sampling_rate
self._sample_width = sample_width
self._channels = channels
@abstractmethod
def is_open(self):
""" Return True if audio source is open, False otherwise """
@abstractmethod
def open(self):
""" Open audio source """
@abstractmethod
def close(self):
""" Close audio source """
@abstractmethod
def read(self, size):
"""
Read and return `size` audio samples at most.
:Parameters:
`size` : int
the number of samples to read.
:Returns:
Audio data as a string of length 'N' * 'sample_width' * 'channels', where 'N' is:
- `size` if `size` < 'left_samples'
- 'left_samples' if `size` > 'left_samples'
"""
def get_sampling_rate(self):
""" Return the number of samples per second of audio stream """
return self.sampling_rate
@property
def sampling_rate(self):
""" Number of samples per second of audio stream """
return self._sampling_rate
@property
def sr(self):
""" Number of samples per second of audio stream """
return self._sampling_rate
def get_sample_width(self):
""" Return the number of bytes used to represent one audio sample """
return self.sample_width
@property
def sample_width(self):
""" Number of bytes used to represent one audio sample """
return self._sample_width
@property
def sw(self):
""" Number of bytes used to represent one audio sample """
return self._sample_width
def get_channels(self):
""" Return the number of channels of this audio source """
return self.channels
@property
def channels(self):
""" Number of channels of this audio source """
return self._channels
@property
def ch(self):
""" Return the number of channels of this audio source """
return self.channels
class Rewindable():
"""
Base class for rewindable audio streams.
Subclasses should implement methods to return to the beginning of an
audio stream as well as method to move to an absolute audio position
expressed in time or in number of samples.
"""
__metaclass__ = ABCMeta
@abstractmethod
def rewind(self):
""" Go back to the beginning of audio stream """
pass
@abstractmethod
def get_position(self):
""" Return the total number of already read samples """
@abstractmethod
def get_time_position(self):
""" Return the total duration in seconds of already read data """
@abstractmethod
def set_position(self, position):
""" Move to an absolute position
:Parameters:
`position` : int
number of samples to skip from the start of the stream
"""
@abstractmethod
def set_time_position(self, time_position):
""" Move to an absolute position expressed in seconds
:Parameters:
`time_position` : float
seconds to skip from the start of the stream
"""
pass
class BufferAudioSource(AudioSource, Rewindable):
"""
An :class:`AudioSource` that encapsulates and reads data from a memory buffer.
It implements methods from :class:`Rewindable` and is therefore a navigable :class:`AudioSource`.
"""
def __init__(self, data_buffer,
sampling_rate=DEFAULT_SAMPLE_RATE,
sample_width=DEFAULT_SAMPLE_WIDTH,
channels=DEFAULT_NB_CHANNELS):
if len(data_buffer) % (sample_width * channels) != 0:
raise ValueError("length of data_buffer must be a multiple of (sample_width * channels)")
AudioSource.__init__(self, sampling_rate, sample_width, channels)
self._buffer = data_buffer
self._index = 0
self._left = 0 if self._buffer is None else len(self._buffer)
self._is_open = False
def is_open(self):
return self._is_open
def open(self):
self._is_open = True
def close(self):
self._is_open = False
self.rewind()
def read(self, size):
if not self._is_open:
raise IOError("Stream is not open")
if self._left > 0:
to_read = size * self.sample_width * self.channels
if to_read > self._left:
to_read = self._left
data = self._buffer[self._index: self._index + to_read]
self._index += to_read
self._left -= to_read
return data
return None
def get_data_buffer(self):
""" Return all audio data as one string buffer. """
return self._buffer
def set_data(self, data_buffer):
""" Set new data for this audio stream.
:Parameters:
`data_buffer` : str, basestring, Bytes
a string buffer with a length multiple of (sample_width * channels)
"""
if len(data_buffer) % (self.sample_width * self.channels) != 0:
raise ValueError("length of data_buffer must be a multiple of (sample_width * channels)")
self._buffer = data_buffer
self._index = 0
self._left = 0 if self._buffer is None else len(self._buffer)
def append_data(self, data_buffer):
""" Append data to this audio stream
:Parameters:
`data_buffer` : str, basestring, Bytes
a buffer with a length multiple of (sample_width * channels)
"""
if len(data_buffer) % (self.sample_width * self.channels) != 0:
raise ValueError("length of data_buffer must be a multiple of (sample_width * channels)")
self._buffer += data_buffer
self._left += len(data_buffer)
def rewind(self):
self.set_position(0)
def get_position(self):
return self._index / self.sample_width
def get_time_position(self):
return float(self._index) / (self.sample_width * self.sampling_rate)
def set_position(self, position):
if position < 0:
raise ValueError("position must be >= 0")
if self._buffer is None:
self._index = 0
self._left = 0
return
position *= self.sample_width
self._index = position if position < len(self._buffer) else len(self._buffer)
self._left = len(self._buffer) - self._index
def set_time_position(self, time_position): # time in seconds
position = int(self.sampling_rate * time_position)
self.set_position(position)
class WaveAudioSource(AudioSource):
"""
A class for an `AudioSource` that reads data from a wave file.
:Parameters:
`filename` :
path to a valid wave file
"""
def __init__(self, filename):
self._filename = filename
self._audio_stream = None
stream = wave.open(self._filename)
AudioSource.__init__(self, stream.getframerate(),
stream.getsampwidth(),
stream.getnchannels())
stream.close()
def is_open(self):
return self._audio_stream is not None
def open(self):
if self._audio_stream is None:
self._audio_stream = wave.open(self._filename)
def close(self):
if self._audio_stream is not None:
self._audio_stream.close()
self._audio_stream = None
def read(self, size):
if self._audio_stream is None:
raise IOError("Stream is not open")
else:
data = self._audio_stream.readframes(size)
if data is None or len(data) < 1:
return None
return data
class PyAudioSource(AudioSource):
"""
A class for an `AudioSource` that reads data from the built-in microphone using PyAudio.
"""
def __init__(self, sampling_rate=DEFAULT_SAMPLE_RATE,
sample_width=DEFAULT_SAMPLE_WIDTH,
channels=DEFAULT_NB_CHANNELS,
frames_per_buffer=1024,
input_device_index=None):
AudioSource.__init__(self, sampling_rate, sample_width, channels)
self._chunk_size = frames_per_buffer
self.input_device_index = input_device_index
import pyaudio
self._pyaudio_object = pyaudio.PyAudio()
self._pyaudio_format = self._pyaudio_object.get_format_from_width(self.sample_width)
self._audio_stream = None
def is_open(self):
return self._audio_stream is not None
def open(self):
self._audio_stream = self._pyaudio_object.open(format=self._pyaudio_format,
channels=self.channels,
rate=self.sampling_rate,
input=True,
output=False,
input_device_index=self.input_device_index,
frames_per_buffer=self._chunk_size)
def close(self):
if self._audio_stream is not None:
self._audio_stream.stop_stream()
self._audio_stream.close()
self._audio_stream = None
def read(self, size):
if self._audio_stream is None:
raise IOError("Stream is not open")
if self._audio_stream.is_active():
data = self._audio_stream.read(size)
if data is None or len(data) < 1:
return None
return data
return None
class StdinAudioSource(AudioSource):
"""
A class for an :class:`AudioSource` that reads data from standard input.
"""
def __init__(self, sampling_rate=DEFAULT_SAMPLE_RATE,
sample_width=DEFAULT_SAMPLE_WIDTH,
channels=DEFAULT_NB_CHANNELS):
AudioSource.__init__(self, sampling_rate, sample_width, channels)
self._is_open = False
def is_open(self):
return self._is_open
def open(self):
self._is_open = True
def close(self):
self._is_open = False
def read(self, size):
if not self._is_open:
raise IOError("Stream is not open")
to_read = size * self.sample_width * self.channels
if sys.version_info >= (3, 0):
data = sys.stdin.buffer.read(to_read)
else:
data = sys.stdin.read(to_read)
if data is None or len(data) < 1:
return None
return data
class PyAudioPlayer():
"""
A class for audio playback using PyAudio.
"""
def __init__(self, sampling_rate=DEFAULT_SAMPLE_RATE,
sample_width=DEFAULT_SAMPLE_WIDTH,
channels=DEFAULT_NB_CHANNELS):
if sample_width not in (1, 2, 4):
raise ValueError("Sample width must be one of: 1, 2 or 4 (bytes)")
self.sampling_rate = sampling_rate
self.sample_width = sample_width
self.channels = channels
import pyaudio
self._p = pyaudio.PyAudio()
self.stream = self._p.open(format=self._p.get_format_from_width(self.sample_width),
channels=self.channels, rate=self.sampling_rate,
input=False, output=True)
def play(self, data):
if self.stream.is_stopped():
self.stream.start_stream()
for chunk in self._chunk_data(data):
self.stream.write(chunk)
self.stream.stop_stream()
def stop(self):
if not self.stream.is_stopped():
self.stream.stop_stream()
self.stream.close()
self._p.terminate()
def _chunk_data(self, data):
# make audio chunks of 100 ms to allow interruption (like ctrl+c)
chunk_size = int((self.sampling_rate * self.sample_width * self.channels) / 10)
start = 0
while start < len(data):
yield data[start: start + chunk_size]
start += chunk_size
def from_file(filename):
"""
Create an `AudioSource` object using the audio file specified by `filename`.
The appropriate :class:`AudioSource` class is guessed from the file's extension.
:Parameters:
`filename` :
path to an audio file.
:Returns:
an `AudioSource` object that reads data from the given file.
"""
if filename.lower().endswith(".wav"):
return WaveAudioSource(filename)
raise Exception("Can not create an AudioSource object from '%s'" % (filename))
def player_for(audio_source):
"""
Return a :class:`PyAudioPlayer` that can play data from `audio_source`.
:Parameters:
`audio_source` :
an `AudioSource` object.
:Returns:
`PyAudioPlayer` that has the same sampling rate, sample width and number of channels
as `audio_source`.
"""
return PyAudioPlayer(audio_source.get_sampling_rate(),
audio_source.get_sample_width(),
audio_source.get_channels())
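A minimal usage sketch for the module above (illustrative only, not part of the commit; "example.wav" is an assumed
filename and playback requires PyAudio):

    # Read a wave file block by block and play it back through a PyAudioPlayer.
    from auditok.io import from_file, player_for

    source = from_file("example.wav")   # WaveAudioSource, guessed from the extension
    player = player_for(source)         # PyAudioPlayer with matching audio parameters
    source.open()
    while True:
        block = source.read(1024)       # up to 1024 samples per call
        if block is None:               # None signals the end of the stream
            break
        player.play(block)
    source.close()
    player.stop()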

843
libs/auditok/util.py Normal file
View file

@ -0,0 +1,843 @@
"""
Class summary
=============
.. autosummary::
DataSource
StringDataSource
ADSFactory
ADSFactory.AudioDataSource
ADSFactory.ADSDecorator
ADSFactory.OverlapADS
ADSFactory.LimiterADS
ADSFactory.RecorderADS
DataValidator
AudioEnergyValidator
"""
from abc import ABCMeta, abstractmethod
import math
from array import array
from .io import Rewindable, from_file, BufferAudioSource, PyAudioSource
from .exceptions import DuplicateArgument
import sys
try:
import numpy
_WITH_NUMPY = True
except ImportError as e:
_WITH_NUMPY = False
try:
from builtins import str
basestring = str
except ImportError as e:
if sys.version_info >= (3, 0):
basestring = str
__all__ = ["DataSource", "DataValidator", "StringDataSource", "ADSFactory", "AudioEnergyValidator"]
class DataSource():
"""
Base class for objects passed to :func:`auditok.core.StreamTokenizer.tokenize`.
Subclasses should implement a :func:`DataSource.read` method.
"""
__metaclass__ = ABCMeta
@abstractmethod
def read(self):
"""
Read a piece of data from this source.
If no more data is available, return None.
"""
class DataValidator():
"""
Base class for a validator object used by :class:`.core.StreamTokenizer` to check
if read data is valid.
Subclasses should implement an :func:`is_valid` method.
"""
__metaclass__ = ABCMeta
@abstractmethod
def is_valid(self, data):
"""
Check whether `data` is valid
"""
class StringDataSource(DataSource):
"""
A class that represents a :class:`DataSource` as a string buffer.
Each call to :func:`DataSource.read` returns one character and moves one step forward.
If the end of the buffer is reached, :func:`read` returns None.
:Parameters:
`data` :
a basestring object.
"""
def __init__(self, data):
self._data = None
self._current = 0
self.set_data(data)
def read(self):
"""
Read one character from buffer.
:Returns:
Current character or None if end of buffer is reached
"""
if self._current >= len(self._data):
return None
self._current += 1
return self._data[self._current - 1]
def set_data(self, data):
"""
Set a new data buffer.
:Parameters:
`data` : a basestring object
New data buffer.
"""
if not isinstance(data, basestring):
raise ValueError("data must an instance of basestring")
self._data = data
self._current = 0
class ADSFactory:
"""
Factory class that makes it easy to create an :class:`ADSFactory.AudioDataSource` object that implements
:class:`DataSource` and can therefore be passed to :func:`auditok.core.StreamTokenizer.tokenize`.
Whether you read audio data from a file, the microphone or a memory buffer, this factory
instantiates and returns the right :class:`ADSFactory.AudioDataSource` object.
There are many other features you might want your :class:`ADSFactory.AudioDataSource` object to have, such as:
memorize all read audio data so that you can rewind and reuse it (especially useful when
reading data from the microphone), read a fixed amount of data (also useful when reading
from the microphone), or read overlapping audio frames (often needed when doing a spectral
analysis of data).
:func:`ADSFactory.ads` automatically creates and returns an object with the desired behavior according
to the supplied keyword arguments.
"""
@staticmethod
def _check_normalize_args(kwargs):
for k in kwargs:
if not k in ["block_dur", "hop_dur", "block_size", "hop_size", "max_time", "record",
"audio_source", "filename", "data_buffer", "frames_per_buffer", "sampling_rate",
"sample_width", "channels", "sr", "sw", "ch", "asrc", "fn", "fpb", "db", "mt",
"rec", "bd", "hd", "bs", "hs"]:
raise ValueError("Invalid argument: {0}".format(k))
if "block_dur" in kwargs and "bd" in kwargs:
raise DuplicateArgument("Either 'block_dur' or 'bd' must be specified, not both")
if "hop_dur" in kwargs and "hd" in kwargs:
raise DuplicateArgument("Either 'hop_dur' or 'hd' must be specified, not both")
if "block_size" in kwargs and "bs" in kwargs:
raise DuplicateArgument("Either 'block_size' or 'bs' must be specified, not both")
if "hop_size" in kwargs and "hs" in kwargs:
raise DuplicateArgument("Either 'hop_size' or 'hs' must be specified, not both")
if "max_time" in kwargs and "mt" in kwargs:
raise DuplicateArgument("Either 'max_time' or 'mt' must be specified, not both")
if "audio_source" in kwargs and "asrc" in kwargs:
raise DuplicateArgument("Either 'audio_source' or 'asrc' must be specified, not both")
if "filename" in kwargs and "fn" in kwargs:
raise DuplicateArgument("Either 'filename' or 'fn' must be specified, not both")
if "data_buffer" in kwargs and "db" in kwargs:
raise DuplicateArgument("Either 'filename' or 'db' must be specified, not both")
if "frames_per_buffer" in kwargs and "fbb" in kwargs:
raise DuplicateArgument("Either 'frames_per_buffer' or 'fpb' must be specified, not both")
if "sampling_rate" in kwargs and "sr" in kwargs:
raise DuplicateArgument("Either 'sampling_rate' or 'sr' must be specified, not both")
if "sample_width" in kwargs and "sw" in kwargs:
raise DuplicateArgument("Either 'sample_width' or 'sw' must be specified, not both")
if "channels" in kwargs and "ch" in kwargs:
raise DuplicateArgument("Either 'channels' or 'ch' must be specified, not both")
if "record" in kwargs and "rec" in kwargs:
raise DuplicateArgument("Either 'record' or 'rec' must be specified, not both")
kwargs["bd"] = kwargs.pop("block_dur", None) or kwargs.pop("bd", None)
kwargs["hd"] = kwargs.pop("hop_dur", None) or kwargs.pop("hd", None)
kwargs["bs"] = kwargs.pop("block_size", None) or kwargs.pop("bs", None)
kwargs["hs"] = kwargs.pop("hop_size", None) or kwargs.pop("hs", None)
kwargs["mt"] = kwargs.pop("max_time", None) or kwargs.pop("mt", None)
kwargs["asrc"] = kwargs.pop("audio_source", None) or kwargs.pop("asrc", None)
kwargs["fn"] = kwargs.pop("filename", None) or kwargs.pop("fn", None)
kwargs["db"] = kwargs.pop("data_buffer", None) or kwargs.pop("db", None)
record = kwargs.pop("record", False)
if not record:
record = kwargs.pop("rec", False)
if not isinstance(record, bool):
raise TypeError("'record' must be a boolean")
kwargs["rec"] = record
# keep long names for arguments meant for BufferAudioSource and PyAudioSource
if "frames_per_buffer" in kwargs or "fpb" in kwargs:
kwargs["frames_per_buffer"] = kwargs.pop("frames_per_buffer", None) or kwargs.pop("fpb", None)
if "sampling_rate" in kwargs or "sr" in kwargs:
kwargs["sampling_rate"] = kwargs.pop("sampling_rate", None) or kwargs.pop("sr", None)
if "sample_width" in kwargs or "sw" in kwargs:
kwargs["sample_width"] = kwargs.pop("sample_width", None) or kwargs.pop("sw", None)
if "channels" in kwargs or "ch" in kwargs:
kwargs["channels"] = kwargs.pop("channels", None) or kwargs.pop("ch", None)
@staticmethod
def ads(**kwargs):
"""
Create and return an :class:`ADSFactory.AudioDataSource`. The type and behavior of the object depend
on the supplied parameters.
:Parameters:
*No parameters* :
read audio data from the available built-in microphone with the default parameters.
The returned :class:`ADSFactory.AudioDataSource` encapsulates an :class:`io.PyAudioSource` object; the
next four parameters, if passed, are used instead of their default values.
`sampling_rate`, `sr` : *(int)*
number of samples per second. Default = 16000.
`sample_width`, `sw` : *(int)*
number of bytes per sample (must be in (1, 2, 4)). Default = 2
`channels`, `ch` : *(int)*
number of audio channels. Default = 1 (only this value is currently accepted)
`frames_per_buffer`, `fpb` : *(int)*
number of samples of PyAudio buffer. Default = 1024.
`audio_source`, `asrc` : an `AudioSource` object
read data from this audio source
`filename`, `fn` : *(string)*
build an `io.AudioSource` object using this file (currently only wave format is supported)
`data_buffer`, `db` : *(string)*
build an `io.BufferAudioSource` using data in `data_buffer`. If this keyword is used,
`sampling_rate`, `sample_width` and `channels` are passed to `io.BufferAudioSource`
constructor and used instead of default values.
`max_time`, `mt` : *(float)*
maximum time (in seconds) to read. Default behavior: read until there is no more data
available.
`record`, `rec` : *(bool)*
save all read data in a cache, providing a navigable object with a `rewind` method.
Default = False.
`block_dur`, `bd` : *(float)*
processing block duration in seconds. This represents the quantity of audio data to return
each time the :func:`read` method is invoked. If `block_dur` is 0.025 (i.e. 25 ms) and the sampling
rate is 8000 and the sample width is 2 bytes, :func:`read` returns a buffer of 0.025 * 8000 * 2 = 400
bytes at most. This parameter will be looked for (and used if available) before `block_size`.
If neither parameter is given, `block_dur` will be set to 0.01 second (i.e. 10 ms)
`hop_dur`, `hd` : *(float)*
quantity of data to skip from the current processing window. If `hop_dur` is supplied then there
will be an overlap of `block_dur` - `hop_dur` between two adjacent blocks. This
parameter will be looked for (and used if available) before `hop_size`. If neither parameter
is given, `hop_dur` will be set to `block_dur` which means that there will be no overlap
between two consecutively read blocks.
`block_size`, `bs` : *(int)*
number of samples to read each time the `read` method is called. Default: a block size
that represents a window of 10ms, so for a sampling rate of 16000, the default `block_size`
is 160 samples, for a rate of 44100, `block_size` = 441 samples, etc.
`hop_size`, `hs` : *(int)*
determines the number of overlapping samples between two adjacent read windows. For a
`hop_size` of value *N*, the overlap is `block_size` - *N*. Default : `hop_size` = `block_size`,
means that there is no overlap.
:Returns:
An AudioDataSource object that has the desired features.
:Examples:
1. **Create an AudioDataSource that reads data from the microphone (requires PyAudio) with default audio parameters:**
.. code:: python
from auditok import ADSFactory
ads = ADSFactory.ads()
ads.get_sampling_rate()
16000
ads.get_sample_width()
2
ads.get_channels()
1
2. **Create an AudioDataSource that reads data from the microphone with a sampling rate of 48KHz:**
.. code:: python
from auditok import ADSFactory
ads = ADSFactory.ads(sr=48000)
ads.get_sampling_rate()
48000
3. **Create an AudioDataSource that reads data from a wave file:**
.. code:: python
import auditok
from auditok import ADSFactory
ads = ADSFactory.ads(fn=auditok.dataset.was_der_mensch_saet_mono_44100_lead_trail_silence)
ads.get_sampling_rate()
44100
ads.get_sample_width()
2
ads.get_channels()
1
4. **Define size of read blocks as 20 ms**
.. code:: python
import auditok
from auditok import ADSFactory
'''
we know the sampling rate for the previous file is 44100 samples/second
so 10 ms are equivalent to 441 samples and 20 ms to 882
'''
block_size = 882
ads = ADSFactory.ads(bs = 882, fn=auditok.dataset.was_der_mensch_saet_mono_44100_lead_trail_silence)
ads.open()
# read one block
data = ads.read()
ads.close()
len(data)
1764
assert len(data) == ads.get_sample_width() * block_size
5. **Define block size as a duration (use block_dur or bd):**
.. code:: python
import auditok
from auditok import ADSFactory
dur = 0.25 # second
ads = ADSFactory.ads(bd = dur, fn=auditok.dataset.was_der_mensch_saet_mono_44100_lead_trail_silence)
'''
we know the sampling rate for the previous file is 44100 samples/second
for a block duration of 250 ms, block size should be 0.25 * 44100 = 11025
'''
ads.get_block_size()
11025
assert ads.get_block_size() == int(0.25 * 44100)
ads.open()
# read one block
data = ads.read()
ads.close()
len(data)
22050
assert len(data) == ads.get_sample_width() * ads.get_block_size()
6. **Read overlapping blocks (one of hop_size, hs, hop_dur or hd > 0):**
For better readability we use a :class:`auditok.io.BufferAudioSource` with a string buffer:
.. code:: python
import auditok
from auditok import ADSFactory
'''
we supply a data buffer instead of a file (keyword 'data_buffer' or 'db')
sr : sampling rate = 16 samples/sec
sw : sample width = 1 byte
ch : channels = 1
'''
buffer = "abcdefghijklmnop" # 16 bytes = 1 second of data
bd = 0.250 # block duration = 250 ms = 4 bytes
hd = 0.125 # hop duration = 125 ms = 2 bytes
ads = ADSFactory.ads(db = "abcdefghijklmnop", bd = bd, hd = hd, sr = 16, sw = 1, ch = 1)
ads.open()
ads.read()
'abcd'
ads.read()
'cdef'
ads.read()
'efgh'
ads.read()
'ghij'
data = ads.read()
assert data == 'ijkl'
7. **Limit amount of read data (use max_time or mt):**
.. code:: python
'''
We know audio file is larger than 2.25 seconds
We want to read up to 2.25 seconds of audio data
'''
ads = ADSFactory.ads(mt = 2.25, fn=auditok.dataset.was_der_mensch_saet_mono_44100_lead_trail_silence)
ads.open()
data = []
while True:
d = ads.read()
if d is None:
break
data.append(d)
ads.close()
data = b''.join(data)
assert len(data) == int(ads.get_sampling_rate() * 2.25 * ads.get_sample_width() * ads.get_channels())
"""
# copy user's dictionary (shallow copy)
kwargs = kwargs.copy()
# check and normalize keyword arguments
ADSFactory._check_normalize_args(kwargs)
block_dur = kwargs.pop("bd")
hop_dur = kwargs.pop("hd")
block_size = kwargs.pop("bs")
hop_size = kwargs.pop("hs")
max_time = kwargs.pop("mt")
audio_source = kwargs.pop("asrc")
filename = kwargs.pop("fn")
data_buffer = kwargs.pop("db")
record = kwargs.pop("rec")
# Case 1: an audio source is supplied
if audio_source is not None:
if (filename, data_buffer) != (None, None):
raise Warning("You should provide one of 'audio_source', 'filename' or 'data_buffer'\
keyword parameters. 'audio_source' will be used")
# Case 2: a file name is supplied
elif filename is not None:
if data_buffer is not None:
raise Warning("You should provide one of 'filename' or 'data_buffer'\
keyword parameters. 'filename' will be used")
audio_source = from_file(filename)
# Case 3: a data_buffer is supplied
elif data_buffer is not None:
audio_source = BufferAudioSource(data_buffer=data_buffer, **kwargs)
# Case 4: try to access native audio input
else:
audio_source = PyAudioSource(**kwargs)
if block_dur is not None:
if block_size is not None:
raise DuplicateArgument("Either 'block_dur' or 'block_size' can be specified, not both")
else:
block_size = int(audio_source.get_sampling_rate() * block_dur)
elif block_size is None:
# Set default block_size to 10 ms
block_size = int(audio_source.get_sampling_rate() / 100)
# Instantiate base AudioDataSource
ads = ADSFactory.AudioDataSource(audio_source=audio_source, block_size=block_size)
# Limit data to be read
if max_time is not None:
ads = ADSFactory.LimiterADS(ads=ads, max_time=max_time)
# Record, rewind and reuse data
if record:
ads = ADSFactory.RecorderADS(ads=ads)
# Read overlapping blocks of data
if hop_dur is not None:
if hop_size is not None:
raise DuplicateArgument("Either 'hop_dur' or 'hop_size' can be specified, not both")
else:
hop_size = int(audio_source.get_sampling_rate() * hop_dur)
if hop_size is not None:
if hop_size <= 0 or hop_size > block_size:
raise ValueError("hop_size must be > 0 and <= block_size")
if hop_size < block_size:
ads = ADSFactory.OverlapADS(ads=ads, hop_size=hop_size)
return ads
class AudioDataSource(DataSource):
"""
Base class for AudioDataSource objects.
It inherits from DataSource and encapsulates an AudioSource object.
"""
def __init__(self, audio_source, block_size):
self.audio_source = audio_source
self.block_size = block_size
def get_block_size(self):
return self.block_size
def set_block_size(self, size):
self.block_size = size
def get_audio_source(self):
return self.audio_source
def set_audio_source(self, audio_source):
self.audio_source = audio_source
def open(self):
self.audio_source.open()
def close(self):
self.audio_source.close()
def is_open(self):
return self.audio_source.is_open()
def get_sampling_rate(self):
return self.audio_source.get_sampling_rate()
def get_sample_width(self):
return self.audio_source.get_sample_width()
def get_channels(self):
return self.audio_source.get_channels()
def rewind(self):
if isinstance(self.audio_source, Rewindable):
self.audio_source.rewind()
else:
raise Exception("Audio source is not rewindable")
def is_rewindable(self):
return isinstance(self.audio_source, Rewindable)
def read(self):
return self.audio_source.read(self.block_size)
class ADSDecorator(AudioDataSource):
"""
Base decorator class for AudioDataSource objects.
"""
__metaclass__ = ABCMeta
def __init__(self, ads):
self.ads = ads
self.get_block_size = self.ads.get_block_size
self.set_block_size = self.ads.set_block_size
self.get_audio_source = self.ads.get_audio_source
self.open = self.ads.open
self.close = self.ads.close
self.is_open = self.ads.is_open
self.get_sampling_rate = self.ads.get_sampling_rate
self.get_sample_width = self.ads.get_sample_width
self.get_channels = self.ads.get_channels
def is_rewindable(self):
return self.ads.is_rewindable()
def rewind(self):
self.ads.rewind()
self._reinit()
def set_audio_source(self, audio_source):
self.ads.set_audio_source(audio_source)
self._reinit()
def open(self):
if not self.ads.is_open():
self.ads.open()
self._reinit()
@abstractmethod
def _reinit(self):
pass
class OverlapADS(ADSDecorator):
"""
A class for AudioDataSource objects that can read and return overlapping
audio frames
"""
def __init__(self, ads, hop_size):
ADSFactory.ADSDecorator.__init__(self, ads)
if hop_size <= 0 or hop_size > self.get_block_size():
raise ValueError("hop_size must be either 'None' or \
between 1 and block_size (both inclusive)")
self.hop_size = hop_size
self._actual_block_size = self.get_block_size()
self._reinit()
def _get_block_size(self):
return self._actual_block_size
def _read_first_block(self):
# For the first call, we need an entire block of size 'block_size'
block = self.ads.read()
if block is None:
return None
# Keep a slice of data in cache and append it in the next call
if len(block) > self._hop_size_bytes:
self._cache = block[self._hop_size_bytes:]
# Up from the next call, we will use '_read_next_blocks'
# and we only read 'hop_size'
self.ads.set_block_size(self.hop_size)
self.read = self._read_next_blocks
return block
def _read_next_blocks(self):
block = self.ads.read()
if block is None:
return None
# Append block to cache data to ensure overlap
block = self._cache + block
# Keep a slice of data in cache only if we have a full length block
# if we don't that means that this is the last block
if len(block) == self._block_size_bytes:
self._cache = block[self._hop_size_bytes:]
else:
self._cache = None
return block
def read(self):
pass
def _reinit(self):
self._cache = None
self.ads.set_block_size(self._actual_block_size)
self._hop_size_bytes = self.hop_size * \
self.get_sample_width() * \
self.get_channels()
self._block_size_bytes = self.get_block_size() * \
self.get_sample_width() * \
self.get_channels()
self.read = self._read_first_block
class LimiterADS(ADSDecorator):
"""
A class for AudioDataSource objects that can read a fixed amount of data.
This can be useful when reading data from the microphone or from large audio files.
"""
def __init__(self, ads, max_time):
ADSFactory.ADSDecorator.__init__(self, ads)
self.max_time = max_time
self._reinit()
def read(self):
if self._total_read_bytes >= self._max_read_bytes:
return None
block = self.ads.read()
if block is None:
return None
self._total_read_bytes += len(block)
if self._total_read_bytes >= self._max_read_bytes:
self.close()
return block
def _reinit(self):
self._max_read_bytes = int(self.max_time * self.get_sampling_rate()) * \
self.get_sample_width() * \
self.get_channels()
self._total_read_bytes = 0
class RecorderADS(ADSDecorator):
"""
A class for AudioDataSource objects that can record all audio data they read,
with a rewind facility.
"""
def __init__(self, ads):
ADSFactory.ADSDecorator.__init__(self, ads)
self._reinit()
def read(self):
pass
def _read_and_rec(self):
# Read and save read data
block = self.ads.read()
if block is not None:
self._cache.append(block)
return block
def _read_simple(self):
# Read without recording
return self.ads.read()
def rewind(self):
if self._record:
# If has been recording, create a new BufferAudioSource
# from recorded data
dbuffer = self._concatenate(self._cache)
asource = BufferAudioSource(dbuffer, self.get_sampling_rate(),
self.get_sample_width(),
self.get_channels())
self.set_audio_source(asource)
self.open()
self._cache = []
self._record = False
self.read = self._read_simple
else:
self.ads.rewind()
if not self.is_open():
self.open()
def is_rewindable(self):
return True
def _reinit(self):
# when audio_source is replaced, start recording again
self._record = True
self._cache = []
self.read = self._read_and_rec
def _concatenate(self, data):
try:
# should always work for python 2
# works for python 3 ONLY if data is a list (or an iterator)
# whose elements are 'bytes' objects
return b''.join(data)
except TypeError:
# work for 'str' in python 2 and python 3
return ''.join(data)
class AudioEnergyValidator(DataValidator):
"""
The most basic auditok audio frame validator.
This validator computes the log energy of an input audio frame
and returns True if the result is >= a given threshold, False
otherwise.
:Parameters:
`sample_width` : *(int)*
Number of bytes of one audio sample. This is used to convert data from `basestring` or `Bytes` to
an array of floats.
`energy_threshold` : *(float)*
A threshold used to check whether an input data buffer is valid.
"""
if _WITH_NUMPY:
_formats = {1: numpy.int8, 2: numpy.int16, 4: numpy.int32}
@staticmethod
def _convert(signal, sample_width):
return numpy.array(numpy.frombuffer(signal, dtype=AudioEnergyValidator._formats[sample_width]),
dtype=numpy.float64)
@staticmethod
def _signal_energy(signal):
return float(numpy.dot(signal, signal)) / len(signal)
@staticmethod
def _signal_log_energy(signal):
energy = AudioEnergyValidator._signal_energy(signal)
if energy <= 0:
return -200
return 10. * numpy.log10(energy)
else:
_formats = {1: 'b', 2: 'h', 4: 'i'}
@staticmethod
def _convert(signal, sample_width):
return array("d", array(AudioEnergyValidator._formats[sample_width], signal))
@staticmethod
def _signal_energy(signal):
energy = 0.
for a in signal:
energy += a * a
return energy / len(signal)
@staticmethod
def _signal_log_energy(signal):
energy = AudioEnergyValidator._signal_energy(signal)
if energy <= 0:
return -200
return 10. * math.log10(energy)
def __init__(self, sample_width, energy_threshold=45):
self.sample_width = sample_width
self._energy_threshold = energy_threshold
def is_valid(self, data):
"""
Check if data is valid. Audio data will be converted into an array (of
signed values) of which the log energy is computed. Log energy is computed
as follows:
.. code:: python
arr = AudioEnergyValidator._convert(signal, sample_width)
energy = float(numpy.dot(arr, arr)) / len(arr)
log_energy = 10. * numpy.log10(energy)
:Parameters:
`data` : either a *string* or a *Bytes* buffer
`data` is converted into a numerical array using the `sample_width`
given in the constructor.
:Returns:
True if `log_energy` >= `energy_threshold`, False otherwise.
"""
signal = AudioEnergyValidator._convert(data, self.sample_width)
return AudioEnergyValidator._signal_log_energy(signal) >= self._energy_threshold
def get_energy_threshold(self):
return self._energy_threshold
def set_energy_threshold(self, threshold):
self._energy_threshold = threshold
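A short sketch of how the helpers above combine (illustrative only; the buffer contents, block duration and
energy threshold are assumptions):

    # Read 10 ms blocks from an in-memory buffer and keep the blocks the
    # energy validator accepts. The all-zero buffer is pure silence, so
    # nothing passes the threshold here.
    from auditok.util import ADSFactory, AudioEnergyValidator

    raw = b"\x00\x00" * 1600   # 0.1 s of 16-bit mono silence at 16 kHz
    ads = ADSFactory.ads(db=raw, sr=16000, sw=2, ch=1, bd=0.01)
    validator = AudioEnergyValidator(sample_width=2, energy_threshold=50)

    ads.open()
    loud_blocks = []
    while True:
        block = ads.read()
        if block is None:
            break
        if validator.is_valid(block):
            loud_blocks.append(block)
    ads.close()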

22
libs/ffmpeg/__init__.py Normal file
View file

@ -0,0 +1,22 @@
from __future__ import unicode_literals
from . import nodes
from . import _ffmpeg
from . import _filters
from . import _probe
from . import _run
from . import _view
from .nodes import *
from ._ffmpeg import *
from ._filters import *
from ._probe import *
from ._run import *
from ._view import *
__all__ = (
nodes.__all__
+ _ffmpeg.__all__
+ _probe.__all__
+ _run.__all__
+ _view.__all__
+ _filters.__all__
)

97
libs/ffmpeg/_ffmpeg.py Normal file
View file

@ -0,0 +1,97 @@
from __future__ import unicode_literals
from past.builtins import basestring
from ._utils import basestring
from .nodes import (
filter_operator,
GlobalNode,
InputNode,
MergeOutputsNode,
OutputNode,
output_operator,
)
def input(filename, **kwargs):
"""Input file URL (ffmpeg ``-i`` option)
Any supplied kwargs are passed to ffmpeg verbatim (e.g. ``t=20``,
``f='mp4'``, ``acodec='pcm'``, etc.).
To tell ffmpeg to read from stdin, use ``pipe:`` as the filename.
Official documentation: `Main options <https://ffmpeg.org/ffmpeg.html#Main-options>`__
"""
kwargs['filename'] = filename
fmt = kwargs.pop('f', None)
if fmt:
if 'format' in kwargs:
raise ValueError("Can't specify both `format` and `f` kwargs")
kwargs['format'] = fmt
return InputNode(input.__name__, kwargs=kwargs).stream()
@output_operator()
def global_args(stream, *args):
"""Add extra global command-line argument(s), e.g. ``-progress``.
"""
return GlobalNode(stream, global_args.__name__, args).stream()
@output_operator()
def overwrite_output(stream):
"""Overwrite output files without asking (ffmpeg ``-y`` option)
Official documentation: `Main options <https://ffmpeg.org/ffmpeg.html#Main-options>`__
"""
return GlobalNode(stream, overwrite_output.__name__, ['-y']).stream()
@output_operator()
def merge_outputs(*streams):
"""Include all given outputs in one ffmpeg command line
"""
return MergeOutputsNode(streams, merge_outputs.__name__).stream()
@filter_operator()
def output(*streams_and_filename, **kwargs):
"""Output file URL
Syntax:
`ffmpeg.output(stream1[, stream2, stream3...], filename, **ffmpeg_args)`
Any supplied keyword arguments are passed to ffmpeg verbatim (e.g.
``t=20``, ``f='mp4'``, ``acodec='pcm'``, ``vcodec='rawvideo'``,
etc.). Some keyword-arguments are handled specially, as shown below.
Args:
video_bitrate: parameter for ``-b:v``, e.g. ``video_bitrate=1000``.
audio_bitrate: parameter for ``-b:a``, e.g. ``audio_bitrate=200``.
format: alias for ``-f`` parameter, e.g. ``format='mp4'``
(equivalent to ``f='mp4'``).
If multiple streams are provided, they are mapped to the same
output.
To tell ffmpeg to write to stdout, use ``pipe:`` as the filename.
Official documentation: `Synopsis <https://ffmpeg.org/ffmpeg.html#Synopsis>`__
"""
streams_and_filename = list(streams_and_filename)
if 'filename' not in kwargs:
if not isinstance(streams_and_filename[-1], basestring):
raise ValueError('A filename must be provided')
kwargs['filename'] = streams_and_filename.pop(-1)
streams = streams_and_filename
fmt = kwargs.pop('f', None)
if fmt:
if 'format' in kwargs:
raise ValueError("Can't specify both `format` and `f` kwargs")
kwargs['format'] = fmt
return OutputNode(streams, output.__name__, kwargs=kwargs).stream()
__all__ = ['input', 'merge_outputs', 'output', 'overwrite_output']
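A minimal sketch of the fluent helpers above (illustrative only; the file names are assumptions and nothing is
executed until the graph is run):

    # Declare a simple transcode: take the first 20 seconds of in.mkv and
    # write them to out.mp4, overwriting the output if it already exists.
    import ffmpeg

    stream = ffmpeg.input('in.mkv', t=20)
    stream = ffmpeg.output(stream, 'out.mp4', acodec='aac')
    stream = ffmpeg.overwrite_output(stream)
    # ffmpeg.run(stream) would build the argument list and invoke the binary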

461
libs/ffmpeg/_filters.py Normal file
View file

@ -0,0 +1,461 @@
from __future__ import unicode_literals
from .nodes import FilterNode, filter_operator
from ._utils import escape_chars
@filter_operator()
def filter_multi_output(stream_spec, filter_name, *args, **kwargs):
"""Apply custom filter with one or more outputs.
This is the same as ``filter`` except that the filter can produce more than one output.
To reference an output stream, use either the ``.stream`` operator or bracket shorthand:
Example:
```
split = ffmpeg.input('in.mp4').filter_multi_output('split')
split0 = split.stream(0)
split1 = split[1]
ffmpeg.concat(split0, split1).output('out.mp4').run()
```
"""
return FilterNode(
stream_spec, filter_name, args=args, kwargs=kwargs, max_inputs=None
)
@filter_operator()
def filter(stream_spec, filter_name, *args, **kwargs):
"""Apply custom filter.
``filter_`` is normally used by higher-level filter functions such as ``hflip``, but if a filter implementation
is missing from ``ffmpeg-python``, you can call ``filter_`` directly to have ``ffmpeg-python`` pass the filter name
and arguments to ffmpeg verbatim.
Args:
stream_spec: a Stream, list of Streams, or label-to-Stream dictionary mapping
filter_name: ffmpeg filter name, e.g. `colorchannelmixer`
*args: list of args to pass to ffmpeg verbatim
**kwargs: list of keyword-args to pass to ffmpeg verbatim
The function name is suffixed with ``_`` in order to avoid confusion with the standard python ``filter`` function.
Example:
``ffmpeg.input('in.mp4').filter('hflip').output('out.mp4').run()``
"""
return filter_multi_output(stream_spec, filter_name, *args, **kwargs).stream()
@filter_operator()
def filter_(stream_spec, filter_name, *args, **kwargs):
"""Alternate name for ``filter``, so as to not collide with the
built-in python ``filter`` operator.
"""
return filter(stream_spec, filter_name, *args, **kwargs)
@filter_operator()
def split(stream):
return FilterNode(stream, split.__name__)
@filter_operator()
def asplit(stream):
return FilterNode(stream, asplit.__name__)
@filter_operator()
def setpts(stream, expr):
"""Change the PTS (presentation timestamp) of the input frames.
Args:
expr: The expression which is evaluated for each frame to construct its timestamp.
Official documentation: `setpts, asetpts <https://ffmpeg.org/ffmpeg-filters.html#setpts_002c-asetpts>`__
"""
return FilterNode(stream, setpts.__name__, args=[expr]).stream()
@filter_operator()
def trim(stream, **kwargs):
"""Trim the input so that the output contains one continuous subpart of the input.
Args:
start: Specify the time of the start of the kept section, i.e. the frame with the timestamp start will be the
first frame in the output.
end: Specify the time of the first frame that will be dropped, i.e. the frame immediately preceding the one
with the timestamp end will be the last frame in the output.
start_pts: This is the same as start, except this option sets the start timestamp in timebase units instead of
seconds.
end_pts: This is the same as end, except this option sets the end timestamp in timebase units instead of
seconds.
duration: The maximum duration of the output in seconds.
start_frame: The number of the first frame that should be passed to the output.
end_frame: The number of the first frame that should be dropped.
Official documentation: `trim <https://ffmpeg.org/ffmpeg-filters.html#trim>`__
"""
return FilterNode(stream, trim.__name__, kwargs=kwargs).stream()
@filter_operator()
def overlay(main_parent_node, overlay_parent_node, eof_action='repeat', **kwargs):
"""Overlay one video on top of another.
Args:
x: Set the expression for the x coordinates of the overlaid video on the main video. Default value is 0. In
case the expression is invalid, it is set to a huge value (meaning that the overlay will not be displayed
within the output visible area).
y: Set the expression for the y coordinates of the overlaid video on the main video. Default value is 0. In
case the expression is invalid, it is set to a huge value (meaning that the overlay will not be displayed
within the output visible area).
eof_action: The action to take when EOF is encountered on the secondary input; it accepts one of the following
values:
* ``repeat``: Repeat the last frame (the default).
* ``endall``: End both streams.
* ``pass``: Pass the main input through.
eval: Set when the expressions for x, and y are evaluated.
It accepts the following values:
* ``init``: only evaluate expressions once during the filter initialization or when a command is
processed
* ``frame``: evaluate expressions for each incoming frame
Default value is ``frame``.
shortest: If set to 1, force the output to terminate when the shortest input terminates. Default value is 0.
format: Set the format for the output video.
It accepts the following values:
* ``yuv420``: force YUV420 output
* ``yuv422``: force YUV422 output
* ``yuv444``: force YUV444 output
* ``rgb``: force packed RGB output
* ``gbrp``: force planar RGB output
Default value is ``yuv420``.
rgb (deprecated): If set to 1, force the filter to accept inputs in the RGB color space. Default value is 0.
This option is deprecated, use format instead.
repeatlast: If set to 1, force the filter to draw the last overlay frame over the main input until the end of
the stream. A value of 0 disables this behavior. Default value is 1.
Official documentation: `overlay <https://ffmpeg.org/ffmpeg-filters.html#overlay-1>`__
"""
kwargs['eof_action'] = eof_action
return FilterNode(
[main_parent_node, overlay_parent_node],
overlay.__name__,
kwargs=kwargs,
max_inputs=2,
).stream()
@filter_operator()
def hflip(stream):
"""Flip the input video horizontally.
Official documentation: `hflip <https://ffmpeg.org/ffmpeg-filters.html#hflip>`__
"""
return FilterNode(stream, hflip.__name__).stream()
@filter_operator()
def vflip(stream):
"""Flip the input video vertically.
Official documentation: `vflip <https://ffmpeg.org/ffmpeg-filters.html#vflip>`__
"""
return FilterNode(stream, vflip.__name__).stream()
@filter_operator()
def crop(stream, x, y, width, height, **kwargs):
"""Crop the input video.
Args:
x: The horizontal position, in the input video, of the left edge of
the output video.
y: The vertical position, in the input video, of the top edge of the
output video.
width: The width of the output video. Must be greater than 0.
height: The height of the output video. Must be greater than 0.
Official documentation: `crop <https://ffmpeg.org/ffmpeg-filters.html#crop>`__
"""
return FilterNode(
stream, crop.__name__, args=[width, height, x, y], kwargs=kwargs
).stream()
@filter_operator()
def drawbox(stream, x, y, width, height, color, thickness=None, **kwargs):
"""Draw a colored box on the input image.
Args:
x: The expression which specifies the top left corner x coordinate of the box. It defaults to 0.
y: The expression which specifies the top left corner y coordinate of the box. It defaults to 0.
width: Specify the width of the box; if 0 it is interpreted as the input width. It defaults to 0.
height: Specify the height of the box; if 0 it is interpreted as the input height. It defaults to 0.
color: Specify the color of the box to write. For the general syntax of this option, check the "Color" section
in the ffmpeg-utils manual. If the special value invert is used, the box edge color is the same as the
video with inverted luma.
thickness: The expression which sets the thickness of the box edge. Default value is 3.
w: Alias for ``width``.
h: Alias for ``height``.
c: Alias for ``color``.
t: Alias for ``thickness``.
Official documentation: `drawbox <https://ffmpeg.org/ffmpeg-filters.html#drawbox>`__
"""
if thickness:
kwargs['t'] = thickness
return FilterNode(
stream, drawbox.__name__, args=[x, y, width, height, color], kwargs=kwargs
).stream()
@filter_operator()
def drawtext(stream, text=None, x=0, y=0, escape_text=True, **kwargs):
"""Draw a text string or text from a specified file on top of a video, using the libfreetype library.
To enable compilation of this filter, you need to configure FFmpeg with ``--enable-libfreetype``. To enable default
font fallback and the font option you need to configure FFmpeg with ``--enable-libfontconfig``. To enable the
text_shaping option, you need to configure FFmpeg with ``--enable-libfribidi``.
Args:
box: Used to draw a box around text using the background color. The value must be either 1 (enable) or 0
(disable). The default value of box is 0.
boxborderw: Set the width of the border to be drawn around the box using boxcolor. The default value of
boxborderw is 0.
boxcolor: The color to be used for drawing box around text. For the syntax of this option, check the "Color"
section in the ffmpeg-utils manual. The default value of boxcolor is "white".
line_spacing: Set the line spacing in pixels of the border to be drawn around the box using box. The default
value of line_spacing is 0.
borderw: Set the width of the border to be drawn around the text using bordercolor. The default value of
borderw is 0.
bordercolor: Set the color to be used for drawing border around text. For the syntax of this option, check the
"Color" section in the ffmpeg-utils manual. The default value of bordercolor is "black".
expansion: Select how the text is expanded. Can be either none, strftime (deprecated) or normal (default). See
the Text expansion section below for details.
basetime: Set a start time for the count. Value is in microseconds. Only applied in the deprecated strftime
expansion mode. To emulate in normal expansion mode use the pts function, supplying the start time (in
seconds) as the second argument.
fix_bounds: If true, check and fix text coords to avoid clipping.
fontcolor: The color to be used for drawing fonts. For the syntax of this option, check the "Color" section in
the ffmpeg-utils manual. The default value of fontcolor is "black".
fontcolor_expr: String which is expanded the same way as text to obtain dynamic fontcolor value. By default
this option has empty value and is not processed. When this option is set, it overrides fontcolor option.
font: The font family to be used for drawing text. By default Sans.
fontfile: The font file to be used for drawing text. The path must be included. This parameter is mandatory if
the fontconfig support is disabled.
alpha: Draw the text applying alpha blending. The value can be a number between 0.0 and 1.0. The expression
accepts the same variables x, y as well. The default value is 1. Please see fontcolor_expr.
fontsize: The font size to be used for drawing text. The default value of fontsize is 16.
text_shaping: If set to 1, attempt to shape the text (for example, reverse the order of right-to-left text and
join Arabic characters) before drawing it. Otherwise, just draw the text exactly as given. By default 1 (if
supported).
ft_load_flags: The flags to be used for loading the fonts. The flags map the corresponding flags supported by
libfreetype, and are a combination of the following values:
* ``default``
* ``no_scale``
* ``no_hinting``
* ``render``
* ``no_bitmap``
* ``vertical_layout``
* ``force_autohint``
* ``crop_bitmap``
* ``pedantic``
* ``ignore_global_advance_width``
* ``no_recurse``
* ``ignore_transform``
* ``monochrome``
* ``linear_design``
* ``no_autohint``
Default value is "default". For more information consult the documentation for the FT_LOAD_* libfreetype
flags.
shadowcolor: The color to be used for drawing a shadow behind the drawn text. For the syntax of this option,
check the "Color" section in the ffmpeg-utils manual. The default value of shadowcolor is "black".
shadowx: The x offset for the text shadow position with respect to the position of the text. It can be either
positive or negative values. The default value is "0".
shadowy: The y offset for the text shadow position with respect to the position of the text. It can be either
positive or negative values. The default value is "0".
start_number: The starting frame number for the n/frame_num variable. The default value is "0".
tabsize: The size in number of spaces to use for rendering the tab. Default value is 4.
timecode: Set the initial timecode representation in "hh:mm:ss[:;.]ff" format. It can be used with or without
text parameter. timecode_rate option must be specified.
rate: Set the timecode frame rate (timecode only).
timecode_rate: Alias for ``rate``.
r: Alias for ``rate``.
tc24hmax: If set to 1, the output of the timecode option will wrap around at 24 hours. Default is 0 (disabled).
text: The text string to be drawn. The text must be a sequence of UTF-8 encoded characters. This parameter is
mandatory if no file is specified with the parameter textfile.
textfile: A text file containing text to be drawn. The text must be a sequence of UTF-8 encoded characters.
This parameter is mandatory if no text string is specified with the parameter text. If both text and
textfile are specified, an error is thrown.
reload: If set to 1, the textfile will be reloaded before each frame. Be sure to update it atomically, or it
may be read partially, or even fail.
x: The expression which specifies the offset where text will be drawn within the video frame. It is relative to
the left border of the output image. The default value is "0".
y: The expression which specifies the offset where text will be drawn within the video frame. It is relative to
the top border of the output image. The default value is "0". See below for the list of accepted constants
and functions.
Expression constants:
The parameters for x and y are expressions containing the following constants and functions:
- dar: input display aspect ratio, it is the same as ``(w / h) * sar``
- hsub: horizontal chroma subsample values. For example for the pixel format "yuv422p" hsub is 2 and vsub
is 1.
- vsub: vertical chroma subsample values. For example for the pixel format "yuv422p" hsub is 2 and vsub
is 1.
- line_h: the height of each text line
- lh: Alias for ``line_h``.
- main_h: the input height
- h: Alias for ``main_h``.
- H: Alias for ``main_h``.
- main_w: the input width
- w: Alias for ``main_w``.
- W: Alias for ``main_w``.
- ascent: the maximum distance from the baseline to the highest/upper grid coordinate used to place a glyph
outline point, for all the rendered glyphs. It is a positive value, due to the grid's orientation with the Y
axis upwards.
- max_glyph_a: Alias for ``ascent``.
- descent: the maximum distance from the baseline to the lowest grid coordinate used to place a glyph outline
point, for all the rendered glyphs. This is a negative value, due to the grid's orientation, with the Y axis
upwards.
- max_glyph_d: Alias for ``descent``.
- max_glyph_h: maximum glyph height, that is the maximum height for all the glyphs contained in the rendered
text, it is equivalent to ascent - descent.
- max_glyph_w: maximum glyph width, that is the maximum width for all the glyphs contained in the rendered
text.
- n: the number of input frame, starting from 0
- rand(min, max): return a random number included between min and max
- sar: The input sample aspect ratio.
- t: timestamp expressed in seconds, NAN if the input timestamp is unknown
- text_h: the height of the rendered text
- th: Alias for ``text_h``.
- text_w: the width of the rendered text
- tw: Alias for ``text_w``.
- x: the x offset coordinates where the text is drawn.
- y: the y offset coordinates where the text is drawn.
These parameters allow the x and y expressions to refer each other, so you can for example specify
``y=x/dar``.
Official documentation: `drawtext <https://ffmpeg.org/ffmpeg-filters.html#drawtext>`__
"""
if text is not None:
if escape_text:
text = escape_chars(text, '\\\'%')
kwargs['text'] = text
if x != 0:
kwargs['x'] = x
if y != 0:
kwargs['y'] = y
return filter(stream, drawtext.__name__, **kwargs)
@filter_operator()
def concat(*streams, **kwargs):
"""Concatenate audio and video streams, joining them together one after the other.
The filter works on segments of synchronized video and audio streams. All segments must have the same number of
streams of each type, and that will also be the number of streams at output.
Args:
unsafe: Activate unsafe mode: do not fail if segments have a different format.
Related streams do not always have exactly the same duration, for various reasons including codec frame size or
sloppy authoring. For that reason, related synchronized streams (e.g. a video and its audio track) should be
concatenated at once. The concat filter will use the duration of the longest stream in each segment (except the
last one), and if necessary pad shorter audio streams with silence.
For this filter to work correctly, all segments must start at timestamp 0.
All corresponding streams must have the same parameters in all segments; the filtering system will automatically
select a common pixel format for video streams, and a common sample format, sample rate and channel layout for
audio streams, but other settings, such as resolution, must be converted explicitly by the user.
Different frame rates are acceptable but will result in variable frame rate at output; be sure to configure the
output file to handle it.
Official documentation: `concat <https://ffmpeg.org/ffmpeg-filters.html#concat>`__
"""
video_stream_count = kwargs.get('v', 1)
audio_stream_count = kwargs.get('a', 0)
stream_count = video_stream_count + audio_stream_count
if len(streams) % stream_count != 0:
raise ValueError(
'Expected concat input streams to have length multiple of {} (v={}, a={}); got {}'.format(
stream_count, video_stream_count, audio_stream_count, len(streams)
)
)
kwargs['n'] = int(len(streams) / stream_count)
return FilterNode(streams, concat.__name__, kwargs=kwargs, max_inputs=None).stream()
@filter_operator()
def zoompan(stream, **kwargs):
"""Apply Zoom & Pan effect.
Args:
zoom: Set the zoom expression. Default is 1.
x: Set the x expression. Default is 0.
y: Set the y expression. Default is 0.
d: Set the duration expression in number of frames. This sets how many frames the effect will last
for a single input image.
s: Set the output image size, default is ``hd720``.
fps: Set the output frame rate, default is 25.
z: Alias for ``zoom``.
Official documentation: `zoompan <https://ffmpeg.org/ffmpeg-filters.html#zoompan>`__
"""
return FilterNode(stream, zoompan.__name__, kwargs=kwargs).stream()
@filter_operator()
def hue(stream, **kwargs):
"""Modify the hue and/or the saturation of the input.
Args:
h: Specify the hue angle as a number of degrees. It accepts an expression, and defaults to "0".
s: Specify the saturation in the [-10,10] range. It accepts an expression and defaults to "1".
H: Specify the hue angle as a number of radians. It accepts an expression, and defaults to "0".
b: Specify the brightness in the [-10,10] range. It accepts an expression and defaults to "0".
Official documentation: `hue <https://ffmpeg.org/ffmpeg-filters.html#hue>`__
"""
return FilterNode(stream, hue.__name__, kwargs=kwargs).stream()
@filter_operator()
def colorchannelmixer(stream, *args, **kwargs):
"""Adjust video input frames by re-mixing color channels.
Official documentation: `colorchannelmixer <https://ffmpeg.org/ffmpeg-filters.html#colorchannelmixer>`__
"""
return FilterNode(stream, colorchannelmixer.__name__, kwargs=kwargs).stream()
__all__ = [
'colorchannelmixer',
'concat',
'crop',
'drawbox',
'drawtext',
'filter',
'filter_',
'filter_multi_output',
'hflip',
'hue',
'overlay',
'setpts',
'trim',
'vflip',
'zoompan',
]
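A brief sketch chaining several of the filters defined above (illustrative only; the file names and crop
geometry are assumptions):

    # Trim a clip, reset its timestamps, mirror it and crop it before writing it out.
    import ffmpeg

    stream = ffmpeg.input('clip.mp4')
    stream = ffmpeg.trim(stream, start=5, end=15)     # keep seconds 5..15
    stream = ffmpeg.setpts(stream, 'PTS-STARTPTS')    # restart timestamps at 0
    stream = ffmpeg.hflip(stream)                     # mirror horizontally
    stream = ffmpeg.crop(stream, 10, 10, 640, 360)    # x, y, width, height
    stream = ffmpeg.output(stream, 'clip_edited.mp4')
    # ffmpeg.run(stream) would execute the generated command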

27
libs/ffmpeg/_probe.py Normal file
View file

@ -0,0 +1,27 @@
import json
import subprocess
from ._run import Error
from ._utils import convert_kwargs_to_cmd_line_args
def probe(filename, cmd='ffprobe', **kwargs):
"""Run ffprobe on the specified file and return a JSON representation of the output.
Raises:
:class:`ffmpeg.Error`: if ffprobe returns a non-zero exit code,
an :class:`Error` is raised with a generic error message.
The stderr output can be retrieved by accessing the
``stderr`` property of the exception.
"""
args = [cmd, '-show_format', '-show_streams', '-of', 'json']
args += convert_kwargs_to_cmd_line_args(kwargs)
args += [filename]
p = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, err = p.communicate()
if p.returncode != 0:
raise Error('ffprobe', out, err)
return json.loads(out.decode('utf-8'))
__all__ = ['probe']
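A small usage sketch for probe() (illustrative only; 'video.mkv' is an assumed path and ffprobe must be on the
PATH or passed via cmd=):

    # Query container and stream information; ffprobe's JSON output is returned as a dict.
    import ffmpeg

    try:
        info = ffmpeg.probe('video.mkv')
        duration = float(info['format']['duration'])           # usually present for local files
        codec_types = [s['codec_type'] for s in info['streams']]
    except ffmpeg.Error as e:
        print(e.stderr)                                         # raw ffprobe stderr (bytes)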

329
libs/ffmpeg/_run.py Normal file
View file

@ -0,0 +1,329 @@
from __future__ import unicode_literals
from .dag import get_outgoing_edges, topo_sort
from ._utils import basestring, convert_kwargs_to_cmd_line_args
from builtins import str
from functools import reduce
import collections
import copy
import operator
import subprocess
from ._ffmpeg import input, output
from .nodes import (
get_stream_spec_nodes,
FilterNode,
GlobalNode,
InputNode,
OutputNode,
output_operator,
)
class Error(Exception):
def __init__(self, cmd, stdout, stderr):
super(Error, self).__init__(
'{} error (see stderr output for detail)'.format(cmd)
)
self.stdout = stdout
self.stderr = stderr
def _get_input_args(input_node):
if input_node.name == input.__name__:
kwargs = copy.copy(input_node.kwargs)
filename = kwargs.pop('filename')
fmt = kwargs.pop('format', None)
video_size = kwargs.pop('video_size', None)
args = []
if fmt:
args += ['-f', fmt]
if video_size:
args += ['-video_size', '{}x{}'.format(video_size[0], video_size[1])]
args += convert_kwargs_to_cmd_line_args(kwargs)
args += ['-i', filename]
else:
raise ValueError('Unsupported input node: {}'.format(input_node))
return args
def _format_input_stream_name(stream_name_map, edge, is_final_arg=False):
prefix = stream_name_map[edge.upstream_node, edge.upstream_label]
if not edge.upstream_selector:
suffix = ''
else:
suffix = ':{}'.format(edge.upstream_selector)
if is_final_arg and isinstance(edge.upstream_node, InputNode):
## Special case: `-map` args should not have brackets for input
## nodes.
fmt = '{}{}'
else:
fmt = '[{}{}]'
return fmt.format(prefix, suffix)
def _format_output_stream_name(stream_name_map, edge):
return '[{}]'.format(stream_name_map[edge.upstream_node, edge.upstream_label])
def _get_filter_spec(node, outgoing_edge_map, stream_name_map):
incoming_edges = node.incoming_edges
outgoing_edges = get_outgoing_edges(node, outgoing_edge_map)
inputs = [
_format_input_stream_name(stream_name_map, edge) for edge in incoming_edges
]
outputs = [
_format_output_stream_name(stream_name_map, edge) for edge in outgoing_edges
]
filter_spec = '{}{}{}'.format(
''.join(inputs), node._get_filter(outgoing_edges), ''.join(outputs)
)
return filter_spec
def _allocate_filter_stream_names(filter_nodes, outgoing_edge_maps, stream_name_map):
stream_count = 0
for upstream_node in filter_nodes:
outgoing_edge_map = outgoing_edge_maps[upstream_node]
for upstream_label, downstreams in sorted(outgoing_edge_map.items()):
if len(downstreams) > 1:
# TODO: automatically insert `splits` ahead of time via graph transformation.
raise ValueError(
'Encountered {} with multiple outgoing edges with same upstream label {!r}; a '
'`split` filter is probably required'.format(
upstream_node, upstream_label
)
)
stream_name_map[upstream_node, upstream_label] = 's{}'.format(stream_count)
stream_count += 1
def _get_filter_arg(filter_nodes, outgoing_edge_maps, stream_name_map):
_allocate_filter_stream_names(filter_nodes, outgoing_edge_maps, stream_name_map)
filter_specs = [
_get_filter_spec(node, outgoing_edge_maps[node], stream_name_map)
for node in filter_nodes
]
return ';'.join(filter_specs)
def _get_global_args(node):
return list(node.args)
def _get_output_args(node, stream_name_map):
if node.name != output.__name__:
raise ValueError('Unsupported output node: {}'.format(node))
args = []
if len(node.incoming_edges) == 0:
raise ValueError('Output node {} has no mapped streams'.format(node))
for edge in node.incoming_edges:
# edge = node.incoming_edges[0]
stream_name = _format_input_stream_name(
stream_name_map, edge, is_final_arg=True
)
if stream_name != '0' or len(node.incoming_edges) > 1:
args += ['-map', stream_name]
kwargs = copy.copy(node.kwargs)
filename = kwargs.pop('filename')
if 'format' in kwargs:
args += ['-f', kwargs.pop('format')]
if 'video_bitrate' in kwargs:
args += ['-b:v', str(kwargs.pop('video_bitrate'))]
if 'audio_bitrate' in kwargs:
args += ['-b:a', str(kwargs.pop('audio_bitrate'))]
if 'video_size' in kwargs:
video_size = kwargs.pop('video_size')
if not isinstance(video_size, basestring) and isinstance(
video_size, collections.Iterable
):
video_size = '{}x{}'.format(video_size[0], video_size[1])
args += ['-video_size', video_size]
args += convert_kwargs_to_cmd_line_args(kwargs)
args += [filename]
return args
@output_operator()
def get_args(stream_spec, overwrite_output=False):
"""Build command-line arguments to be passed to ffmpeg."""
nodes = get_stream_spec_nodes(stream_spec)
args = []
# TODO: group nodes together, e.g. `-i somefile -r somerate`.
sorted_nodes, outgoing_edge_maps = topo_sort(nodes)
input_nodes = [node for node in sorted_nodes if isinstance(node, InputNode)]
output_nodes = [node for node in sorted_nodes if isinstance(node, OutputNode)]
global_nodes = [node for node in sorted_nodes if isinstance(node, GlobalNode)]
filter_nodes = [node for node in sorted_nodes if isinstance(node, FilterNode)]
stream_name_map = {(node, None): str(i) for i, node in enumerate(input_nodes)}
filter_arg = _get_filter_arg(filter_nodes, outgoing_edge_maps, stream_name_map)
args += reduce(operator.add, [_get_input_args(node) for node in input_nodes])
if filter_arg:
args += ['-filter_complex', filter_arg]
args += reduce(
operator.add, [_get_output_args(node, stream_name_map) for node in output_nodes]
)
args += reduce(operator.add, [_get_global_args(node) for node in global_nodes], [])
if overwrite_output:
args += ['-y']
return args
@output_operator()
def compile(stream_spec, cmd='ffmpeg', overwrite_output=False):
"""Build command-line for invoking ffmpeg.
The :meth:`run` function uses this to build the command line
arguments and should work in most cases, but calling this function
directly is useful for debugging or if you need to invoke ffmpeg
manually for whatever reason.
This is the same as calling :meth:`get_args` except that it also
includes the ``ffmpeg`` command as the first argument.
"""
if isinstance(cmd, basestring):
cmd = [cmd]
elif type(cmd) != list:
cmd = list(cmd)
return cmd + get_args(stream_spec, overwrite_output=overwrite_output)
@output_operator()
def run_async(
stream_spec,
cmd='ffmpeg',
pipe_stdin=False,
pipe_stdout=False,
pipe_stderr=False,
quiet=False,
overwrite_output=False,
):
"""Asynchronously invoke ffmpeg for the supplied node graph.
Args:
pipe_stdin: if True, connect pipe to subprocess stdin (to be
used with ``pipe:`` ffmpeg inputs).
pipe_stdout: if True, connect pipe to subprocess stdout (to be
used with ``pipe:`` ffmpeg outputs).
pipe_stderr: if True, connect pipe to subprocess stderr.
quiet: shorthand for setting ``pipe_stdout`` and
``pipe_stderr``.
**kwargs: keyword-arguments passed to ``get_args()`` (e.g.
``overwrite_output=True``).
Returns:
A `subprocess Popen`_ object representing the child process.
Examples:
Run and stream input::
process = (
ffmpeg
.input('pipe:', format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width, height))
.output(out_filename, pix_fmt='yuv420p')
.overwrite_output()
.run_async(pipe_stdin=True)
)
process.communicate(input=input_data)
Run and capture output::
process = (
ffmpeg
.input(in_filename)
.output('pipe:', format='rawvideo', pix_fmt='rgb24')
.run_async(pipe_stdout=True, pipe_stderr=True)
)
out, err = process.communicate()
Process video frame-by-frame using numpy::
process1 = (
ffmpeg
.input(in_filename)
.output('pipe:', format='rawvideo', pix_fmt='rgb24')
.run_async(pipe_stdout=True)
)
process2 = (
ffmpeg
.input('pipe:', format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width, height))
.output(out_filename, pix_fmt='yuv420p')
.overwrite_output()
.run_async(pipe_stdin=True)
)
while True:
in_bytes = process1.stdout.read(width * height * 3)
if not in_bytes:
break
in_frame = (
np
.frombuffer(in_bytes, np.uint8)
.reshape([height, width, 3])
)
out_frame = in_frame * 0.3
process2.stdin.write(
out_frame
.astype(np.uint8)
.tobytes()
)
process2.stdin.close()
process1.wait()
process2.wait()
.. _subprocess Popen: https://docs.python.org/3/library/subprocess.html#popen-objects
"""
args = compile(stream_spec, cmd, overwrite_output=overwrite_output)
stdin_stream = subprocess.PIPE if pipe_stdin else None
stdout_stream = subprocess.PIPE if pipe_stdout or quiet else None
stderr_stream = subprocess.PIPE if pipe_stderr or quiet else None
return subprocess.Popen(
args, stdin=stdin_stream, stdout=stdout_stream, stderr=stderr_stream
)
@output_operator()
def run(
stream_spec,
cmd='ffmpeg',
capture_stdout=False,
capture_stderr=False,
input=None,
quiet=False,
overwrite_output=False,
):
"""Invoke ffmpeg for the supplied node graph.
Args:
capture_stdout: if True, capture stdout (to be used with
``pipe:`` ffmpeg outputs).
capture_stderr: if True, capture stderr.
quiet: shorthand for setting ``capture_stdout`` and ``capture_stderr``.
input: text to be sent to stdin (to be used with ``pipe:``
ffmpeg inputs)
**kwargs: keyword-arguments passed to ``get_args()`` (e.g.
``overwrite_output=True``).
Returns: (out, err) tuple containing captured stdout and stderr data.
"""
process = run_async(
stream_spec,
cmd,
pipe_stdin=input is not None,
pipe_stdout=capture_stdout,
pipe_stderr=capture_stderr,
quiet=quiet,
overwrite_output=overwrite_output,
)
out, err = process.communicate(input)
retcode = process.poll()
if retcode:
raise Error('ffmpeg', out, err)
return out, err
__all__ = ['compile', 'Error', 'get_args', 'run', 'run_async']
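# Illustrative sketch only (hypothetical filenames; not part of ffmpeg-python):
# how the functions above compose. get_args() builds the argument list,
# compile() prepends the executable, and run() invokes ffmpeg.
def _example_run_usage():  # hypothetical helper, shown for clarity
    import ffmpeg
    stream = ffmpeg.input('in.mp4').output('out.mp4')
    assert ffmpeg.get_args(stream) == ['-i', 'in.mp4', 'out.mp4']
    assert ffmpeg.compile(stream) == ['ffmpeg', '-i', 'in.mp4', 'out.mp4']
    # ffmpeg.run(stream, overwrite_output=True) would actually invoke ffmpeg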

97
libs/ffmpeg/_utils.py Normal file
View file

@ -0,0 +1,97 @@
from __future__ import unicode_literals
from builtins import str
from past.builtins import basestring
import hashlib
import sys
if sys.version_info.major == 2:
# noinspection PyUnresolvedReferences,PyShadowingBuiltins
str = str
# `past.builtins.basestring` module can't be imported on Python3 in some environments (Ubuntu).
# This code is copy-pasted from it to avoid crashes.
class BaseBaseString(type):
def __instancecheck__(cls, instance):
return isinstance(instance, (bytes, str))
def __subclasshook__(cls, thing):
# TODO: What should go here?
return NotImplemented
def with_metaclass(meta, *bases):
class metaclass(meta):
__call__ = type.__call__
__init__ = type.__init__
def __new__(cls, name, this_bases, d):
if this_bases is None:
return type.__new__(cls, name, (), d)
return meta(name, bases, d)
return metaclass('temporary_class', None, {})
if sys.version_info.major >= 3:
class basestring(with_metaclass(BaseBaseString)):
pass
else:
# noinspection PyUnresolvedReferences,PyCompatibility
from builtins import basestring
def _recursive_repr(item):
"""Hack around python `repr` to deterministically represent dictionaries.
This is able to represent more things than json.dumps, since it does not require things to be JSON serializable
(e.g. datetimes).
"""
if isinstance(item, basestring):
result = str(item)
elif isinstance(item, list):
result = '[{}]'.format(', '.join([_recursive_repr(x) for x in item]))
elif isinstance(item, dict):
kv_pairs = [
'{}: {}'.format(_recursive_repr(k), _recursive_repr(item[k]))
for k in sorted(item)
]
result = '{' + ', '.join(kv_pairs) + '}'
else:
result = repr(item)
return result
def get_hash(item):
repr_ = _recursive_repr(item).encode('utf-8')
return hashlib.md5(repr_).hexdigest()
def get_hash_int(item):
return int(get_hash(item), base=16)
def escape_chars(text, chars):
"""Helper function to escape uncomfortable characters."""
text = str(text)
chars = list(set(chars))
if '\\' in chars:
chars.remove('\\')
chars.insert(0, '\\')
for ch in chars:
text = text.replace(ch, '\\' + ch)
return text
def convert_kwargs_to_cmd_line_args(kwargs):
"""Helper function to build command line arguments out of dict."""
args = []
for k in sorted(kwargs.keys()):
v = kwargs[k]
args.append('-{}'.format(k))
if v is not None:
args.append('{}'.format(v))
return args
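# Illustrative sketch only (arbitrary values; not part of the library): expected
# behavior of the helpers above, mirroring the escape/kwargs tests in this commit.
def _example_utils_usage():  # hypothetical helper, shown for clarity
    # kwargs become sorted '-key value' pairs; a None value emits a bare flag
    assert convert_kwargs_to_cmd_line_args({'vcodec': 'libx264', 'an': None}) == \
        ['-an', '-vcodec', 'libx264']
    # requested characters are backslash-escaped
    assert escape_chars('a:b', ':') == 'a\\:b'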

108
libs/ffmpeg/_view.py Normal file
View file

@ -0,0 +1,108 @@
from __future__ import unicode_literals
from builtins import str
from .dag import get_outgoing_edges
from ._run import topo_sort
import tempfile
from ffmpeg.nodes import (
FilterNode,
get_stream_spec_nodes,
InputNode,
OutputNode,
stream_operator,
)
_RIGHT_ARROW = '\u2192'
def _get_node_color(node):
if isinstance(node, InputNode):
color = '#99cc00'
elif isinstance(node, OutputNode):
color = '#99ccff'
elif isinstance(node, FilterNode):
color = '#ffcc00'
else:
color = None
return color
@stream_operator()
def view(stream_spec, detail=False, filename=None, pipe=False, **kwargs):
try:
import graphviz
except ImportError:
raise ImportError(
'failed to import graphviz; please make sure graphviz is installed (e.g. `pip install '
'graphviz`)'
)
show_labels = kwargs.pop('show_labels', True)
if pipe and filename is not None:
raise ValueError('Can\'t specify both `filename` and `pipe`')
elif not pipe and filename is None:
filename = tempfile.mktemp()
nodes = get_stream_spec_nodes(stream_spec)
sorted_nodes, outgoing_edge_maps = topo_sort(nodes)
graph = graphviz.Digraph(format='png')
graph.attr(rankdir='LR')
if len(list(kwargs.keys())) != 0:
raise ValueError(
'Invalid kwargs key(s): {}'.format(', '.join(list(kwargs.keys())))
)
for node in sorted_nodes:
color = _get_node_color(node)
if detail:
lines = [node.short_repr]
lines += ['{!r}'.format(arg) for arg in node.args]
lines += [
'{}={!r}'.format(key, node.kwargs[key]) for key in sorted(node.kwargs)
]
node_text = '\n'.join(lines)
else:
node_text = node.short_repr
graph.node(
str(hash(node)), node_text, shape='box', style='filled', fillcolor=color
)
outgoing_edge_map = outgoing_edge_maps.get(node, {})
for edge in get_outgoing_edges(node, outgoing_edge_map):
kwargs = {}
up_label = edge.upstream_label
down_label = edge.downstream_label
up_selector = edge.upstream_selector
if show_labels and (
up_label is not None
or down_label is not None
or up_selector is not None
):
if up_label is None:
up_label = ''
if up_selector is not None:
up_label += ":" + up_selector
if down_label is None:
down_label = ''
if up_label != '' and down_label != '':
middle = ' {} '.format(_RIGHT_ARROW)
else:
middle = ''
kwargs['label'] = '{} {} {}'.format(up_label, middle, down_label)
upstream_node_id = str(hash(edge.upstream_node))
downstream_node_id = str(hash(edge.downstream_node))
graph.edge(upstream_node_id, downstream_node_id, **kwargs)
if pipe:
return graph.pipe()
else:
graph.view(filename, cleanup=True)
return stream_spec
__all__ = ['view']
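# Illustrative sketch only (requires the optional graphviz package; filenames
# are placeholders): view() renders the node graph instead of running ffmpeg.
def _example_view_usage():  # hypothetical helper, shown for clarity
    import ffmpeg
    stream = ffmpeg.input('in.mp4').hflip().output('out.mp4')
    stream.view(filename='graph')                    # renders and opens graph.png
    png_bytes = stream.view(detail=True, pipe=True)  # or return raw PNG bytes
    return png_bytes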

231
libs/ffmpeg/dag.py Normal file
View file

@ -0,0 +1,231 @@
from __future__ import unicode_literals
from ._utils import get_hash, get_hash_int
from builtins import object
from collections import namedtuple
class DagNode(object):
"""Node in a directed-acyclic graph (DAG).
Edges:
DagNodes are connected by edges. An edge connects two nodes with a label for each side:
- ``upstream_node``: upstream/parent node
- ``upstream_label``: label on the outgoing side of the upstream node
- ``downstream_node``: downstream/child node
- ``downstream_label``: label on the incoming side of the downstream node
For example, DagNode A may be connected to DagNode B with an edge labelled "foo" on A's side, and "bar" on B's
side:
_____ _____
| | | |
| A >[foo]---[bar]> B |
|_____| |_____|
Edge labels may be integers or strings, and nodes cannot have more than one incoming edge with the same label.
DagNodes may have any number of incoming edges and any number of outgoing edges. DagNodes keep track only of
their incoming edges, but the entire graph structure can be inferred by looking at the furthest downstream
nodes and working backwards.
Hashing:
DagNodes must be hashable, and two nodes are considered to be equivalent if they have the same hash value.
Nodes are immutable, and the hash should remain constant as a result. If a node with new contents is required,
create a new node and throw the old one away.
String representation:
In order for graph visualization tools to show useful information, nodes must be representable as strings. The
``repr`` operator should provide a more or less "full" representation of the node, and the ``short_repr``
property should be a shortened, concise representation.
Again, because nodes are immutable, the string representations should remain constant.
"""
def __hash__(self):
"""Return an integer hash of the node."""
raise NotImplementedError()
def __eq__(self, other):
"""Compare two nodes; implementations should return True if (and only if) hashes match."""
raise NotImplementedError()
def __repr__(self):
"""Return a full string representation of the node."""
raise NotImplementedError()
@property
def short_repr(self):
"""Return a partial/concise representation of the node."""
raise NotImplementedError()
@property
def incoming_edge_map(self):
"""Provides information about all incoming edges that connect to this node.
The edge map is a dictionary that maps an ``incoming_label`` to ``(outgoing_node, outgoing_label)``. Note that
implicitly, ``incoming_node`` is ``self``. See "Edges" section above.
"""
raise NotImplementedError()
DagEdge = namedtuple(
'DagEdge',
[
'downstream_node',
'downstream_label',
'upstream_node',
'upstream_label',
'upstream_selector',
],
)
def get_incoming_edges(downstream_node, incoming_edge_map):
edges = []
for downstream_label, upstream_info in list(incoming_edge_map.items()):
upstream_node, upstream_label, upstream_selector = upstream_info
edges += [
DagEdge(
downstream_node,
downstream_label,
upstream_node,
upstream_label,
upstream_selector,
)
]
return edges
def get_outgoing_edges(upstream_node, outgoing_edge_map):
edges = []
for upstream_label, downstream_infos in sorted(outgoing_edge_map.items()):
for downstream_info in downstream_infos:
downstream_node, downstream_label, downstream_selector = downstream_info
edges += [
DagEdge(
downstream_node,
downstream_label,
upstream_node,
upstream_label,
downstream_selector,
)
]
return edges
class KwargReprNode(DagNode):
"""A DagNode that can be represented as a set of args+kwargs.
"""
@property
def __upstream_hashes(self):
hashes = []
for downstream_label, upstream_info in list(self.incoming_edge_map.items()):
upstream_node, upstream_label, upstream_selector = upstream_info
hashes += [
hash(x)
for x in [
downstream_label,
upstream_node,
upstream_label,
upstream_selector,
]
]
return hashes
@property
def __inner_hash(self):
props = {'args': self.args, 'kwargs': self.kwargs}
return get_hash(props)
def __get_hash(self):
hashes = self.__upstream_hashes + [self.__inner_hash]
return get_hash_int(hashes)
def __init__(self, incoming_edge_map, name, args, kwargs):
self.__incoming_edge_map = incoming_edge_map
self.name = name
self.args = args
self.kwargs = kwargs
self.__hash = self.__get_hash()
def __hash__(self):
return self.__hash
def __eq__(self, other):
return hash(self) == hash(other)
@property
def short_hash(self):
return '{:x}'.format(abs(hash(self)))[:12]
def long_repr(self, include_hash=True):
formatted_props = ['{!r}'.format(arg) for arg in self.args]
formatted_props += [
'{}={!r}'.format(key, self.kwargs[key]) for key in sorted(self.kwargs)
]
out = '{}({})'.format(self.name, ', '.join(formatted_props))
if include_hash:
out += ' <{}>'.format(self.short_hash)
return out
def __repr__(self):
return self.long_repr()
@property
def incoming_edges(self):
return get_incoming_edges(self, self.incoming_edge_map)
@property
def incoming_edge_map(self):
return self.__incoming_edge_map
@property
def short_repr(self):
return self.name
def topo_sort(downstream_nodes):
marked_nodes = []
sorted_nodes = []
outgoing_edge_maps = {}
def visit(
upstream_node,
upstream_label,
downstream_node,
downstream_label,
downstream_selector=None,
):
if upstream_node in marked_nodes:
raise RuntimeError('Graph is not a DAG')
if downstream_node is not None:
outgoing_edge_map = outgoing_edge_maps.get(upstream_node, {})
outgoing_edge_infos = outgoing_edge_map.get(upstream_label, [])
outgoing_edge_infos += [
(downstream_node, downstream_label, downstream_selector)
]
outgoing_edge_map[upstream_label] = outgoing_edge_infos
outgoing_edge_maps[upstream_node] = outgoing_edge_map
if upstream_node not in sorted_nodes:
marked_nodes.append(upstream_node)
for edge in upstream_node.incoming_edges:
visit(
edge.upstream_node,
edge.upstream_label,
edge.downstream_node,
edge.downstream_label,
edge.upstream_selector,
)
marked_nodes.remove(upstream_node)
sorted_nodes.append(upstream_node)
unmarked_nodes = [(node, None) for node in downstream_nodes]
while unmarked_nodes:
upstream_node, upstream_label = unmarked_nodes.pop()
visit(upstream_node, upstream_label, None, None)
return sorted_nodes, outgoing_edge_maps
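# Illustrative sketch only (nodes are normally built via the higher-level Node
# classes in nodes.py; constructing KwargReprNode directly just keeps this
# example self-contained): topo_sort() returns nodes upstream-first along with
# the reconstructed outgoing-edge maps.
def _example_topo_sort():  # hypothetical helper, shown for clarity
    src = KwargReprNode({}, 'src', [], {})
    dst = KwargReprNode({None: (src, None, None)}, 'dst', [], {})
    sorted_nodes, outgoing_edge_maps = topo_sort([dst])
    assert sorted_nodes == [src, dst]
    assert outgoing_edge_maps[src] == {None: [(dst, None, None)]}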

377
libs/ffmpeg/nodes.py Normal file
View file

@ -0,0 +1,377 @@
from __future__ import unicode_literals
from past.builtins import basestring
from .dag import KwargReprNode
from ._utils import escape_chars, get_hash_int
from builtins import object
import os
def _is_of_types(obj, types):
valid = False
for stream_type in types:
if isinstance(obj, stream_type):
valid = True
break
return valid
def _get_types_str(types):
return ', '.join(['{}.{}'.format(x.__module__, x.__name__) for x in types])
class Stream(object):
"""Represents the outgoing edge of an upstream node; may be used to create more downstream nodes."""
def __init__(
self, upstream_node, upstream_label, node_types, upstream_selector=None
):
if not _is_of_types(upstream_node, node_types):
raise TypeError(
'Expected upstream node to be of one of the following type(s): {}; got {}'.format(
_get_types_str(node_types), type(upstream_node)
)
)
self.node = upstream_node
self.label = upstream_label
self.selector = upstream_selector
def __hash__(self):
return get_hash_int([hash(self.node), hash(self.label)])
def __eq__(self, other):
return hash(self) == hash(other)
def __repr__(self):
node_repr = self.node.long_repr(include_hash=False)
selector = ''
if self.selector:
selector = ':{}'.format(self.selector)
out = '{}[{!r}{}] <{}>'.format(
node_repr, self.label, selector, self.node.short_hash
)
return out
def __getitem__(self, index):
"""
Select a component (audio, video) of the stream.
Example:
Process the audio and video portions of a stream independently::
input = ffmpeg.input('in.mp4')
audio = input['a'].filter("aecho", 0.8, 0.9, 1000, 0.3)
video = input['v'].hflip()
out = ffmpeg.output(audio, video, 'out.mp4')
"""
if self.selector is not None:
raise ValueError('Stream already has a selector: {}'.format(self))
elif not isinstance(index, basestring):
raise TypeError("Expected string index (e.g. 'a'); got {!r}".format(index))
return self.node.stream(label=self.label, selector=index)
@property
def audio(self):
"""Select the audio-portion of a stream.
Some ffmpeg filters drop audio streams, and care must be taken
to preserve the audio in the final output. The ``.audio`` and
``.video`` operators can be used to reference the audio/video
portions of a stream so that they can be processed separately
and then re-combined later in the pipeline. This dilemma is
intrinsic to ffmpeg, and ffmpeg-python tries to stay out of the
way while users may refer to the official ffmpeg documentation
as to why certain filters drop audio.
``stream.audio`` is a shorthand for ``stream['a']``.
Example:
Process the audio and video portions of a stream independently::
input = ffmpeg.input('in.mp4')
audio = input.audio.filter("aecho", 0.8, 0.9, 1000, 0.3)
video = input.video.hflip()
out = ffmpeg.output(audio, video, 'out.mp4')
"""
return self['a']
@property
def video(self):
"""Select the video-portion of a stream.
Some ffmpeg filters drop audio streams, and care must be taken
to preserve the audio in the final output. The ``.audio`` and
``.video`` operators can be used to reference the audio/video
portions of a stream so that they can be processed separately
and then re-combined later in the pipeline. This dilemma is
intrinsic to ffmpeg, and ffmpeg-python tries to stay out of the
way while users may refer to the official ffmpeg documentation
as to why certain filters drop audio.
``stream.video`` is a shorthand for ``stream['v']``.
Example:
Process the audio and video portions of a stream independently::
input = ffmpeg.input('in.mp4')
audio = input.audio.filter("aecho", 0.8, 0.9, 1000, 0.3)
video = input.video.hflip()
out = ffmpeg.output(audio, video, 'out.mp4')
"""
return self['v']
def get_stream_map(stream_spec):
if stream_spec is None:
stream_map = {}
elif isinstance(stream_spec, Stream):
stream_map = {None: stream_spec}
elif isinstance(stream_spec, (list, tuple)):
stream_map = dict(enumerate(stream_spec))
elif isinstance(stream_spec, dict):
stream_map = stream_spec
return stream_map
def get_stream_map_nodes(stream_map):
nodes = []
for stream in list(stream_map.values()):
if not isinstance(stream, Stream):
raise TypeError('Expected Stream; got {}'.format(type(stream)))
nodes.append(stream.node)
return nodes
def get_stream_spec_nodes(stream_spec):
stream_map = get_stream_map(stream_spec)
return get_stream_map_nodes(stream_map)
class Node(KwargReprNode):
"""Node base"""
@classmethod
def __check_input_len(cls, stream_map, min_inputs, max_inputs):
if min_inputs is not None and len(stream_map) < min_inputs:
raise ValueError(
'Expected at least {} input stream(s); got {}'.format(
min_inputs, len(stream_map)
)
)
elif max_inputs is not None and len(stream_map) > max_inputs:
raise ValueError(
'Expected at most {} input stream(s); got {}'.format(
max_inputs, len(stream_map)
)
)
@classmethod
def __check_input_types(cls, stream_map, incoming_stream_types):
for stream in list(stream_map.values()):
if not _is_of_types(stream, incoming_stream_types):
raise TypeError(
'Expected incoming stream(s) to be of one of the following types: {}; got {}'.format(
_get_types_str(incoming_stream_types), type(stream)
)
)
@classmethod
def __get_incoming_edge_map(cls, stream_map):
incoming_edge_map = {}
for downstream_label, upstream in list(stream_map.items()):
incoming_edge_map[downstream_label] = (
upstream.node,
upstream.label,
upstream.selector,
)
return incoming_edge_map
def __init__(
self,
stream_spec,
name,
incoming_stream_types,
outgoing_stream_type,
min_inputs,
max_inputs,
args=[],
kwargs={},
):
stream_map = get_stream_map(stream_spec)
self.__check_input_len(stream_map, min_inputs, max_inputs)
self.__check_input_types(stream_map, incoming_stream_types)
incoming_edge_map = self.__get_incoming_edge_map(stream_map)
super(Node, self).__init__(incoming_edge_map, name, args, kwargs)
self.__outgoing_stream_type = outgoing_stream_type
self.__incoming_stream_types = incoming_stream_types
def stream(self, label=None, selector=None):
"""Create an outgoing stream originating from this node.
More nodes may be attached onto the outgoing stream.
"""
return self.__outgoing_stream_type(self, label, upstream_selector=selector)
def __getitem__(self, item):
"""Create an outgoing stream originating from this node; syntactic sugar for ``self.stream(label)``.
It can also be used to apply a selector: e.g. ``node[0:'a']`` returns a stream with label 0 and
selector ``'a'``, which is the same as ``node.stream(label=0, selector='a')``.
Example:
Process the audio and video portions of a stream independently::
input = ffmpeg.input('in.mp4')
audio = input[:'a'].filter("aecho", 0.8, 0.9, 1000, 0.3)
video = input[:'v'].hflip()
out = ffmpeg.output(audio, video, 'out.mp4')
"""
if isinstance(item, slice):
return self.stream(label=item.start, selector=item.stop)
else:
return self.stream(label=item)
class FilterableStream(Stream):
def __init__(self, upstream_node, upstream_label, upstream_selector=None):
super(FilterableStream, self).__init__(
upstream_node, upstream_label, {InputNode, FilterNode}, upstream_selector
)
# noinspection PyMethodOverriding
class InputNode(Node):
"""InputNode type"""
def __init__(self, name, args=[], kwargs={}):
super(InputNode, self).__init__(
stream_spec=None,
name=name,
incoming_stream_types={},
outgoing_stream_type=FilterableStream,
min_inputs=0,
max_inputs=0,
args=args,
kwargs=kwargs,
)
@property
def short_repr(self):
return os.path.basename(self.kwargs['filename'])
# noinspection PyMethodOverriding
class FilterNode(Node):
def __init__(self, stream_spec, name, max_inputs=1, args=[], kwargs={}):
super(FilterNode, self).__init__(
stream_spec=stream_spec,
name=name,
incoming_stream_types={FilterableStream},
outgoing_stream_type=FilterableStream,
min_inputs=1,
max_inputs=max_inputs,
args=args,
kwargs=kwargs,
)
"""FilterNode"""
def _get_filter(self, outgoing_edges):
args = self.args
kwargs = self.kwargs
if self.name in ('split', 'asplit'):
args = [len(outgoing_edges)]
out_args = [escape_chars(x, '\\\'=:') for x in args]
out_kwargs = {}
for k, v in list(kwargs.items()):
k = escape_chars(k, '\\\'=:')
v = escape_chars(v, '\\\'=:')
out_kwargs[k] = v
arg_params = [escape_chars(v, '\\\'=:') for v in out_args]
kwarg_params = ['{}={}'.format(k, out_kwargs[k]) for k in sorted(out_kwargs)]
params = arg_params + kwarg_params
params_text = escape_chars(self.name, '\\\'=:')
if params:
params_text += '={}'.format(':'.join(params))
return escape_chars(params_text, '\\\'[],;')
# noinspection PyMethodOverriding
class OutputNode(Node):
def __init__(self, stream, name, args=[], kwargs={}):
super(OutputNode, self).__init__(
stream_spec=stream,
name=name,
incoming_stream_types={FilterableStream},
outgoing_stream_type=OutputStream,
min_inputs=1,
max_inputs=None,
args=args,
kwargs=kwargs,
)
@property
def short_repr(self):
return os.path.basename(self.kwargs['filename'])
class OutputStream(Stream):
def __init__(self, upstream_node, upstream_label, upstream_selector=None):
super(OutputStream, self).__init__(
upstream_node,
upstream_label,
{OutputNode, GlobalNode, MergeOutputsNode},
upstream_selector=upstream_selector,
)
# noinspection PyMethodOverriding
class MergeOutputsNode(Node):
def __init__(self, streams, name):
super(MergeOutputsNode, self).__init__(
stream_spec=streams,
name=name,
incoming_stream_types={OutputStream},
outgoing_stream_type=OutputStream,
min_inputs=1,
max_inputs=None,
)
# noinspection PyMethodOverriding
class GlobalNode(Node):
def __init__(self, stream, name, args=[], kwargs={}):
super(GlobalNode, self).__init__(
stream_spec=stream,
name=name,
incoming_stream_types={OutputStream},
outgoing_stream_type=OutputStream,
min_inputs=1,
max_inputs=1,
args=args,
kwargs=kwargs,
)
def stream_operator(stream_classes={Stream}, name=None):
def decorator(func):
func_name = name or func.__name__
[setattr(stream_class, func_name, func) for stream_class in stream_classes]
return func
return decorator
def filter_operator(name=None):
return stream_operator(stream_classes={FilterableStream}, name=name)
def output_operator(name=None):
return stream_operator(stream_classes={OutputStream}, name=name)
__all__ = ['Stream']
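# Illustrative sketch only (hypothetical filter; not part of ffmpeg-python): the
# operator decorators above attach the decorated function as a method on the
# chosen stream classes, which is how the fluent API (e.g. stream.hflip(),
# stream.output(...)) is assembled elsewhere in the package.
@filter_operator()
def _example_double_speed(stream):
    # becomes callable as some_filterable_stream._example_double_speed()
    return FilterNode(stream, 'setpts', args=['0.5*PTS']).stream()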

View file

Binary file not shown.

Binary file not shown.


View file

@ -0,0 +1,771 @@
from __future__ import unicode_literals
from builtins import bytes
from builtins import range
from builtins import str
import ffmpeg
import os
import pytest
import random
import re
import subprocess
try:
import mock # python 2
except ImportError:
from unittest import mock # python 3
TEST_DIR = os.path.dirname(__file__)
SAMPLE_DATA_DIR = os.path.join(TEST_DIR, 'sample_data')
TEST_INPUT_FILE1 = os.path.join(SAMPLE_DATA_DIR, 'in1.mp4')
TEST_OVERLAY_FILE = os.path.join(SAMPLE_DATA_DIR, 'overlay.png')
TEST_OUTPUT_FILE1 = os.path.join(SAMPLE_DATA_DIR, 'out1.mp4')
TEST_OUTPUT_FILE2 = os.path.join(SAMPLE_DATA_DIR, 'out2.mp4')
BOGUS_INPUT_FILE = os.path.join(SAMPLE_DATA_DIR, 'bogus')
subprocess.check_call(['ffmpeg', '-version'])
def test_escape_chars():
assert ffmpeg._utils.escape_chars('a:b', ':') == 'a\:b'
assert ffmpeg._utils.escape_chars('a\\:b', ':\\') == 'a\\\\\\:b'
assert (
ffmpeg._utils.escape_chars('a:b,c[d]e%{}f\'g\'h\\i', '\\\':,[]%')
== 'a\\:b\\,c\\[d\\]e\\%{}f\\\'g\\\'h\\\\i'
)
assert ffmpeg._utils.escape_chars(123, ':\\') == '123'
def test_fluent_equality():
base1 = ffmpeg.input('dummy1.mp4')
base2 = ffmpeg.input('dummy1.mp4')
base3 = ffmpeg.input('dummy2.mp4')
t1 = base1.trim(start_frame=10, end_frame=20)
t2 = base1.trim(start_frame=10, end_frame=20)
t3 = base1.trim(start_frame=10, end_frame=30)
t4 = base2.trim(start_frame=10, end_frame=20)
t5 = base3.trim(start_frame=10, end_frame=20)
assert t1 == t2
assert t1 != t3
assert t1 == t4
assert t1 != t5
def test_fluent_concat():
base = ffmpeg.input('dummy.mp4')
trimmed1 = base.trim(start_frame=10, end_frame=20)
trimmed2 = base.trim(start_frame=30, end_frame=40)
trimmed3 = base.trim(start_frame=50, end_frame=60)
concat1 = ffmpeg.concat(trimmed1, trimmed2, trimmed3)
concat2 = ffmpeg.concat(trimmed1, trimmed2, trimmed3)
concat3 = ffmpeg.concat(trimmed1, trimmed3, trimmed2)
assert concat1 == concat2
assert concat1 != concat3
def test_fluent_output():
ffmpeg.input('dummy.mp4').trim(start_frame=10, end_frame=20).output('dummy2.mp4')
def test_fluent_complex_filter():
in_file = ffmpeg.input('dummy.mp4')
return ffmpeg.concat(
in_file.trim(start_frame=10, end_frame=20),
in_file.trim(start_frame=30, end_frame=40),
in_file.trim(start_frame=50, end_frame=60),
).output('dummy2.mp4')
def test_node_repr():
in_file = ffmpeg.input('dummy.mp4')
trim1 = ffmpeg.trim(in_file, start_frame=10, end_frame=20)
trim2 = ffmpeg.trim(in_file, start_frame=30, end_frame=40)
trim3 = ffmpeg.trim(in_file, start_frame=50, end_frame=60)
concatted = ffmpeg.concat(trim1, trim2, trim3)
output = ffmpeg.output(concatted, 'dummy2.mp4')
assert repr(in_file.node) == 'input(filename={!r}) <{}>'.format(
'dummy.mp4', in_file.node.short_hash
)
assert repr(trim1.node) == 'trim(end_frame=20, start_frame=10) <{}>'.format(
trim1.node.short_hash
)
assert repr(trim2.node) == 'trim(end_frame=40, start_frame=30) <{}>'.format(
trim2.node.short_hash
)
assert repr(trim3.node) == 'trim(end_frame=60, start_frame=50) <{}>'.format(
trim3.node.short_hash
)
assert repr(concatted.node) == 'concat(n=3) <{}>'.format(concatted.node.short_hash)
assert repr(output.node) == 'output(filename={!r}) <{}>'.format(
'dummy2.mp4', output.node.short_hash
)
def test_stream_repr():
in_file = ffmpeg.input('dummy.mp4')
assert repr(in_file) == 'input(filename={!r})[None] <{}>'.format(
'dummy.mp4', in_file.node.short_hash
)
split0 = in_file.filter_multi_output('split')[0]
assert repr(split0) == 'split()[0] <{}>'.format(split0.node.short_hash)
dummy_out = in_file.filter_multi_output('dummy')['out']
assert repr(dummy_out) == 'dummy()[{!r}] <{}>'.format(
dummy_out.label, dummy_out.node.short_hash
)
def test__get_args__simple():
out_file = ffmpeg.input('dummy.mp4').output('dummy2.mp4')
assert out_file.get_args() == ['-i', 'dummy.mp4', 'dummy2.mp4']
def test_global_args():
out_file = (
ffmpeg.input('dummy.mp4')
.output('dummy2.mp4')
.global_args('-progress', 'someurl')
)
assert out_file.get_args() == [
'-i',
'dummy.mp4',
'dummy2.mp4',
'-progress',
'someurl',
]
def _get_simple_example():
return ffmpeg.input(TEST_INPUT_FILE1).output(TEST_OUTPUT_FILE1)
def _get_complex_filter_example():
split = ffmpeg.input(TEST_INPUT_FILE1).vflip().split()
split0 = split[0]
split1 = split[1]
overlay_file = ffmpeg.input(TEST_OVERLAY_FILE)
overlay_file = ffmpeg.crop(overlay_file, 10, 10, 158, 112)
return (
ffmpeg.concat(
split0.trim(start_frame=10, end_frame=20),
split1.trim(start_frame=30, end_frame=40),
)
.overlay(overlay_file.hflip())
.drawbox(50, 50, 120, 120, color='red', thickness=5)
.output(TEST_OUTPUT_FILE1)
.overwrite_output()
)
def test__get_args__complex_filter():
out = _get_complex_filter_example()
args = ffmpeg.get_args(out)
assert args == [
'-i',
TEST_INPUT_FILE1,
'-i',
TEST_OVERLAY_FILE,
'-filter_complex',
'[0]vflip[s0];'
'[s0]split=2[s1][s2];'
'[s1]trim=end_frame=20:start_frame=10[s3];'
'[s2]trim=end_frame=40:start_frame=30[s4];'
'[s3][s4]concat=n=2[s5];'
'[1]crop=158:112:10:10[s6];'
'[s6]hflip[s7];'
'[s5][s7]overlay=eof_action=repeat[s8];'
'[s8]drawbox=50:50:120:120:red:t=5[s9]',
'-map',
'[s9]',
TEST_OUTPUT_FILE1,
'-y',
]
def test_combined_output():
i1 = ffmpeg.input(TEST_INPUT_FILE1)
i2 = ffmpeg.input(TEST_OVERLAY_FILE)
out = ffmpeg.output(i1, i2, TEST_OUTPUT_FILE1)
assert out.get_args() == [
'-i',
TEST_INPUT_FILE1,
'-i',
TEST_OVERLAY_FILE,
'-map',
'0',
'-map',
'1',
TEST_OUTPUT_FILE1,
]
@pytest.mark.parametrize('use_shorthand', [True, False])
def test_filter_with_selector(use_shorthand):
i = ffmpeg.input(TEST_INPUT_FILE1)
if use_shorthand:
v1 = i.video.hflip()
a1 = i.audio.filter('aecho', 0.8, 0.9, 1000, 0.3)
else:
v1 = i['v'].hflip()
a1 = i['a'].filter('aecho', 0.8, 0.9, 1000, 0.3)
out = ffmpeg.output(a1, v1, TEST_OUTPUT_FILE1)
assert out.get_args() == [
'-i',
TEST_INPUT_FILE1,
'-filter_complex',
'[0:a]aecho=0.8:0.9:1000:0.3[s0];' '[0:v]hflip[s1]',
'-map',
'[s0]',
'-map',
'[s1]',
TEST_OUTPUT_FILE1,
]
def test_get_item_with_bad_selectors():
input = ffmpeg.input(TEST_INPUT_FILE1)
with pytest.raises(ValueError) as excinfo:
input['a']['a']
assert str(excinfo.value).startswith('Stream already has a selector:')
with pytest.raises(TypeError) as excinfo:
input[:'a']
assert str(excinfo.value).startswith("Expected string index (e.g. 'a')")
with pytest.raises(TypeError) as excinfo:
input[5]
assert str(excinfo.value).startswith("Expected string index (e.g. 'a')")
def _get_complex_filter_asplit_example():
split = ffmpeg.input(TEST_INPUT_FILE1).vflip().asplit()
split0 = split[0]
split1 = split[1]
return (
ffmpeg.concat(
split0.filter('atrim', start=10, end=20),
split1.filter('atrim', start=30, end=40),
)
.output(TEST_OUTPUT_FILE1)
.overwrite_output()
)
def test_filter_concat__video_only():
in1 = ffmpeg.input('in1.mp4')
in2 = ffmpeg.input('in2.mp4')
args = ffmpeg.concat(in1, in2).output('out.mp4').get_args()
assert args == [
'-i',
'in1.mp4',
'-i',
'in2.mp4',
'-filter_complex',
'[0][1]concat=n=2[s0]',
'-map',
'[s0]',
'out.mp4',
]
def test_filter_concat__audio_only():
in1 = ffmpeg.input('in1.mp4')
in2 = ffmpeg.input('in2.mp4')
args = ffmpeg.concat(in1, in2, v=0, a=1).output('out.mp4').get_args()
assert args == [
'-i',
'in1.mp4',
'-i',
'in2.mp4',
'-filter_complex',
'[0][1]concat=a=1:n=2:v=0[s0]',
'-map',
'[s0]',
'out.mp4',
]
def test_filter_concat__audio_video():
in1 = ffmpeg.input('in1.mp4')
in2 = ffmpeg.input('in2.mp4')
joined = ffmpeg.concat(in1.video, in1.audio, in2.hflip(), in2['a'], v=1, a=1).node
args = ffmpeg.output(joined[0], joined[1], 'out.mp4').get_args()
assert args == [
'-i',
'in1.mp4',
'-i',
'in2.mp4',
'-filter_complex',
'[1]hflip[s0];[0:v][0:a][s0][1:a]concat=a=1:n=2:v=1[s1][s2]',
'-map',
'[s1]',
'-map',
'[s2]',
'out.mp4',
]
def test_filter_concat__wrong_stream_count():
in1 = ffmpeg.input('in1.mp4')
in2 = ffmpeg.input('in2.mp4')
with pytest.raises(ValueError) as excinfo:
ffmpeg.concat(in1.video, in1.audio, in2.hflip(), v=1, a=1).node
assert (
str(excinfo.value)
== 'Expected concat input streams to have length multiple of 2 (v=1, a=1); got 3'
)
def test_filter_asplit():
out = _get_complex_filter_asplit_example()
args = out.get_args()
assert args == [
'-i',
TEST_INPUT_FILE1,
'-filter_complex',
'[0]vflip[s0];[s0]asplit=2[s1][s2];[s1]atrim=end=20:start=10[s3];[s2]atrim=end=40:start=30[s4];[s3]'
'[s4]concat=n=2[s5]',
'-map',
'[s5]',
TEST_OUTPUT_FILE1,
'-y',
]
def test__output__bitrate():
args = (
ffmpeg.input('in')
.output('out', video_bitrate=1000, audio_bitrate=200)
.get_args()
)
assert args == ['-i', 'in', '-b:v', '1000', '-b:a', '200', 'out']
@pytest.mark.parametrize('video_size', [(320, 240), '320x240'])
def test__output__video_size(video_size):
args = ffmpeg.input('in').output('out', video_size=video_size).get_args()
assert args == ['-i', 'in', '-video_size', '320x240', 'out']
def test_filter_normal_arg_escape():
"""Test string escaping of normal filter args (e.g. ``font`` param of ``drawtext`` filter)."""
def _get_drawtext_font_repr(font):
"""Build a command-line arg using drawtext ``font`` param and extract the ``-filter_complex`` arg."""
args = (
ffmpeg.input('in')
.drawtext('test', font='a{}b'.format(font))
.output('out')
.get_args()
)
assert args[:3] == ['-i', 'in', '-filter_complex']
assert args[4:] == ['-map', '[s0]', 'out']
match = re.match(
r'\[0\]drawtext=font=a((.|\n)*)b:text=test\[s0\]', args[3], re.MULTILINE
)
assert match is not None, 'Invalid -filter_complex arg: {!r}'.format(args[3])
return match.group(1)
expected_backslash_counts = {
'x': 0,
'\'': 3,
'\\': 3,
'%': 0,
':': 2,
',': 1,
'[': 1,
']': 1,
'=': 2,
'\n': 0,
}
for ch, expected_backslash_count in list(expected_backslash_counts.items()):
expected = '{}{}'.format('\\' * expected_backslash_count, ch)
actual = _get_drawtext_font_repr(ch)
assert expected == actual
def test_filter_text_arg_str_escape():
"""Test string escaping of normal filter args (e.g. ``text`` param of ``drawtext`` filter)."""
def _get_drawtext_text_repr(text):
"""Build a command-line arg using drawtext ``text`` param and extract the ``-filter_complex`` arg."""
args = ffmpeg.input('in').drawtext('a{}b'.format(text)).output('out').get_args()
assert args[:3] == ['-i', 'in', '-filter_complex']
assert args[4:] == ['-map', '[s0]', 'out']
match = re.match(r'\[0\]drawtext=text=a((.|\n)*)b\[s0\]', args[3], re.MULTILINE)
assert match is not None, 'Invalid -filter_complex arg: {!r}'.format(args[3])
return match.group(1)
expected_backslash_counts = {
'x': 0,
'\'': 7,
'\\': 7,
'%': 4,
':': 2,
',': 1,
'[': 1,
']': 1,
'=': 2,
'\n': 0,
}
for ch, expected_backslash_count in list(expected_backslash_counts.items()):
expected = '{}{}'.format('\\' * expected_backslash_count, ch)
actual = _get_drawtext_text_repr(ch)
assert expected == actual
# def test_version():
# subprocess.check_call(['ffmpeg', '-version'])
def test__compile():
out_file = ffmpeg.input('dummy.mp4').output('dummy2.mp4')
assert out_file.compile() == ['ffmpeg', '-i', 'dummy.mp4', 'dummy2.mp4']
assert out_file.compile(cmd='ffmpeg.old') == [
'ffmpeg.old',
'-i',
'dummy.mp4',
'dummy2.mp4',
]
@pytest.mark.parametrize('pipe_stdin', [True, False])
@pytest.mark.parametrize('pipe_stdout', [True, False])
@pytest.mark.parametrize('pipe_stderr', [True, False])
def test__run_async(mocker, pipe_stdin, pipe_stdout, pipe_stderr):
process__mock = mock.Mock()
popen__mock = mocker.patch.object(subprocess, 'Popen', return_value=process__mock)
stream = _get_simple_example()
process = ffmpeg.run_async(
stream, pipe_stdin=pipe_stdin, pipe_stdout=pipe_stdout, pipe_stderr=pipe_stderr
)
assert process is process__mock
expected_stdin = subprocess.PIPE if pipe_stdin else None
expected_stdout = subprocess.PIPE if pipe_stdout else None
expected_stderr = subprocess.PIPE if pipe_stderr else None
(args,), kwargs = popen__mock.call_args
assert args == ffmpeg.compile(stream)
assert kwargs == dict(
stdin=expected_stdin, stdout=expected_stdout, stderr=expected_stderr
)
def test__run():
stream = _get_complex_filter_example()
out, err = ffmpeg.run(stream)
assert out is None
assert err is None
@pytest.mark.parametrize('capture_stdout', [True, False])
@pytest.mark.parametrize('capture_stderr', [True, False])
def test__run__capture_out(mocker, capture_stdout, capture_stderr):
mocker.patch.object(ffmpeg._run, 'compile', return_value=['echo', 'test'])
stream = _get_simple_example()
out, err = ffmpeg.run(
stream, capture_stdout=capture_stdout, capture_stderr=capture_stderr
)
if capture_stdout:
assert out == 'test\n'.encode()
else:
assert out is None
if capture_stderr:
assert err == ''.encode()
else:
assert err is None
def test__run__input_output(mocker):
mocker.patch.object(ffmpeg._run, 'compile', return_value=['cat'])
stream = _get_simple_example()
out, err = ffmpeg.run(stream, input='test'.encode(), capture_stdout=True)
assert out == 'test'.encode()
assert err is None
@pytest.mark.parametrize('capture_stdout', [True, False])
@pytest.mark.parametrize('capture_stderr', [True, False])
def test__run__error(mocker, capture_stdout, capture_stderr):
mocker.patch.object(ffmpeg._run, 'compile', return_value=['ffmpeg'])
stream = _get_complex_filter_example()
with pytest.raises(ffmpeg.Error) as excinfo:
out, err = ffmpeg.run(
stream, capture_stdout=capture_stdout, capture_stderr=capture_stderr
)
assert str(excinfo.value) == 'ffmpeg error (see stderr output for detail)'
out = excinfo.value.stdout
err = excinfo.value.stderr
if capture_stdout:
assert out == ''.encode()
else:
assert out is None
if capture_stderr:
assert err.decode().startswith('ffmpeg version')
else:
assert err is None
def test__run__multi_output():
in_ = ffmpeg.input(TEST_INPUT_FILE1)
out1 = in_.output(TEST_OUTPUT_FILE1)
out2 = in_.output(TEST_OUTPUT_FILE2)
ffmpeg.run([out1, out2], overwrite_output=True)
def test__run__dummy_cmd():
stream = _get_complex_filter_example()
ffmpeg.run(stream, cmd='true')
def test__run__dummy_cmd_list():
stream = _get_complex_filter_example()
ffmpeg.run(stream, cmd=['true', 'ignored'])
def test__filter__custom():
stream = ffmpeg.input('dummy.mp4')
stream = ffmpeg.filter(stream, 'custom_filter', 'a', 'b', kwarg1='c')
stream = ffmpeg.output(stream, 'dummy2.mp4')
assert stream.get_args() == [
'-i',
'dummy.mp4',
'-filter_complex',
'[0]custom_filter=a:b:kwarg1=c[s0]',
'-map',
'[s0]',
'dummy2.mp4',
]
def test__filter__custom_fluent():
stream = (
ffmpeg.input('dummy.mp4')
.filter('custom_filter', 'a', 'b', kwarg1='c')
.output('dummy2.mp4')
)
assert stream.get_args() == [
'-i',
'dummy.mp4',
'-filter_complex',
'[0]custom_filter=a:b:kwarg1=c[s0]',
'-map',
'[s0]',
'dummy2.mp4',
]
def test__merge_outputs():
in_ = ffmpeg.input('in.mp4')
out1 = in_.output('out1.mp4')
out2 = in_.output('out2.mp4')
assert ffmpeg.merge_outputs(out1, out2).get_args() == [
'-i',
'in.mp4',
'out1.mp4',
'out2.mp4',
]
assert ffmpeg.get_args([out1, out2]) == ['-i', 'in.mp4', 'out2.mp4', 'out1.mp4']
def test__input__start_time():
assert ffmpeg.input('in', ss=10.5).output('out').get_args() == [
'-ss',
'10.5',
'-i',
'in',
'out',
]
assert ffmpeg.input('in', ss=0.0).output('out').get_args() == [
'-ss',
'0.0',
'-i',
'in',
'out',
]
def test_multi_passthrough():
out1 = ffmpeg.input('in1.mp4').output('out1.mp4')
out2 = ffmpeg.input('in2.mp4').output('out2.mp4')
out = ffmpeg.merge_outputs(out1, out2)
assert ffmpeg.get_args(out) == [
'-i',
'in1.mp4',
'-i',
'in2.mp4',
'out1.mp4',
'-map',
'1',
'out2.mp4',
]
assert ffmpeg.get_args([out1, out2]) == [
'-i',
'in2.mp4',
'-i',
'in1.mp4',
'out2.mp4',
'-map',
'1',
'out1.mp4',
]
def test_passthrough_selectors():
i1 = ffmpeg.input(TEST_INPUT_FILE1)
args = ffmpeg.output(i1['1'], i1['2'], TEST_OUTPUT_FILE1).get_args()
assert args == [
'-i',
TEST_INPUT_FILE1,
'-map',
'0:1',
'-map',
'0:2',
TEST_OUTPUT_FILE1,
]
def test_mixed_passthrough_selectors():
i1 = ffmpeg.input(TEST_INPUT_FILE1)
args = ffmpeg.output(i1['1'].hflip(), i1['2'], TEST_OUTPUT_FILE1).get_args()
assert args == [
'-i',
TEST_INPUT_FILE1,
'-filter_complex',
'[0:1]hflip[s0]',
'-map',
'[s0]',
'-map',
'0:2',
TEST_OUTPUT_FILE1,
]
def test_pipe():
width = 32
height = 32
frame_size = width * height * 3 # 3 bytes for rgb24
frame_count = 10
start_frame = 2
out = (
ffmpeg.input(
'pipe:0',
format='rawvideo',
pixel_format='rgb24',
video_size=(width, height),
framerate=10,
)
.trim(start_frame=start_frame)
.output('pipe:1', format='rawvideo')
)
args = out.get_args()
assert args == [
'-f',
'rawvideo',
'-video_size',
'{}x{}'.format(width, height),
'-framerate',
'10',
'-pixel_format',
'rgb24',
'-i',
'pipe:0',
'-filter_complex',
'[0]trim=start_frame=2[s0]',
'-map',
'[s0]',
'-f',
'rawvideo',
'pipe:1',
]
cmd = ['ffmpeg'] + args
p = subprocess.Popen(
cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
in_data = bytes(
bytearray([random.randint(0, 255) for _ in range(frame_size * frame_count)])
)
p.stdin.write(in_data) # note: this could block, in which case need to use threads
p.stdin.close()
out_data = p.stdout.read()
assert len(out_data) == frame_size * (frame_count - start_frame)
assert out_data == in_data[start_frame * frame_size :]
def test__probe():
data = ffmpeg.probe(TEST_INPUT_FILE1)
assert set(data.keys()) == {'format', 'streams'}
assert data['format']['duration'] == '7.036000'
def test__probe__exception():
with pytest.raises(ffmpeg.Error) as excinfo:
ffmpeg.probe(BOGUS_INPUT_FILE)
assert str(excinfo.value) == 'ffprobe error (see stderr output for detail)'
assert 'No such file or directory'.encode() in excinfo.value.stderr
def test__probe__extra_args():
data = ffmpeg.probe(TEST_INPUT_FILE1, show_frames=None)
assert set(data.keys()) == {'format', 'streams', 'frames'}
def get_filter_complex_input(flt, name):
m = re.search(r'\[([^]]+)\]{}(?=[[;]|$)'.format(name), flt)
if m:
return m.group(1)
else:
return None
def get_filter_complex_outputs(flt, name):
m = re.search(r'(^|[];]){}((\[[^]]+\])+)(?=;|$)'.format(name), flt)
if m:
return m.group(2)[1:-1].split('][')
else:
return None
def test__get_filter_complex_input():
assert get_filter_complex_input("", "scale") is None
assert get_filter_complex_input("scale", "scale") is None
assert get_filter_complex_input("scale[s3][s4];etc", "scale") is None
assert get_filter_complex_input("[s2]scale", "scale") == "s2"
assert get_filter_complex_input("[s2]scale;etc", "scale") == "s2"
assert get_filter_complex_input("[s2]scale[s3][s4];etc", "scale") == "s2"
def test__get_filter_complex_outputs():
assert get_filter_complex_outputs("", "scale") is None
assert get_filter_complex_outputs("scale", "scale") is None
assert get_filter_complex_outputs("scalex[s0][s1]", "scale") is None
assert get_filter_complex_outputs("scale[s0][s1]", "scale") == ['s0', 's1']
assert get_filter_complex_outputs("[s5]scale[s0][s1]", "scale") == ['s0', 's1']
assert get_filter_complex_outputs("[s5]scale[s1][s0]", "scale") == ['s1', 's0']
assert get_filter_complex_outputs("[s5]scale[s1]", "scale") == ['s1']
assert get_filter_complex_outputs("[s5]scale[s1];x", "scale") == ['s1']
assert get_filter_complex_outputs("y;[s5]scale[s1];x", "scale") == ['s1']
def test__multi_output_edge_label_order():
scale2ref = ffmpeg.filter_multi_output(
[ffmpeg.input('x'), ffmpeg.input('y')], 'scale2ref'
)
out = ffmpeg.merge_outputs(
scale2ref[1].filter('scale').output('a'),
scale2ref[10000].filter('hflip').output('b'),
)
args = out.get_args()
flt_cmpl = args[args.index('-filter_complex') + 1]
out1, out2 = get_filter_complex_outputs(flt_cmpl, 'scale2ref')
assert out1 == get_filter_complex_input(flt_cmpl, 'scale')
assert out2 == get_filter_complex_input(flt_cmpl, 'hflip')

View file

@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-
from .version import __version__ # noqa
from .ffsubsync import main # noqa

View file

@ -0,0 +1,87 @@
# -*- coding: utf-8 -*-
import logging
import math
import numpy as np
from .sklearn_shim import TransformerMixin
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class FailedToFindAlignmentException(Exception):
pass
class FFTAligner(TransformerMixin):
def __init__(self):
self.best_offset_ = None
self.best_score_ = None
self.get_score_ = False
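# fit() binarizes both inputs to +/-1, cross-correlates them via FFT, and
# stores the best-scoring lag (in samples) as best_offset_.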
def fit(self, refstring, substring, get_score=False):
refstring, substring = [
list(map(int, s))
if isinstance(s, str) else s
for s in [refstring, substring]
]
refstring, substring = map(
lambda s: 2 * np.array(s).astype(float) - 1, [refstring, substring])
total_bits = math.log(len(substring) + len(refstring), 2)
total_length = int(2 ** math.ceil(total_bits))
extra_zeros = total_length - len(substring) - len(refstring)
subft = np.fft.fft(np.append(np.zeros(extra_zeros + len(refstring)), substring))
refft = np.fft.fft(np.flip(np.append(refstring, np.zeros(len(substring) + extra_zeros)), 0))
convolve = np.real(np.fft.ifft(subft * refft))
best_idx = np.argmax(convolve)
self.best_offset_ = len(convolve) - 1 - best_idx - len(substring)
self.best_score_ = convolve[best_idx]
self.get_score_ = get_score
return self
def transform(self, *_):
if self.get_score_:
return self.best_score_, self.best_offset_
else:
return self.best_offset_
class MaxScoreAligner(TransformerMixin):
def __init__(self, base_aligner, sample_rate=None, max_offset_seconds=None):
if isinstance(base_aligner, type):
self.base_aligner = base_aligner()
else:
self.base_aligner = base_aligner
self.max_offset_seconds = max_offset_seconds
if sample_rate is None or max_offset_seconds is None:
self.max_offset_samples = None
else:
self.max_offset_samples = abs(max_offset_seconds * sample_rate)
self._scores = []
def fit(self, refstring, subpipes):
if not isinstance(subpipes, list):
subpipes = [subpipes]
for subpipe in subpipes:
if hasattr(subpipe, 'transform'):
substring = subpipe.transform(None)
else:
substring = subpipe
self._scores.append((
self.base_aligner.fit_transform(
refstring, substring, get_score=True
),
subpipe
))
return self
def transform(self, *_):
scores = self._scores
if self.max_offset_samples is not None:
scores = list(filter(lambda s: abs(s[0][1]) <= self.max_offset_samples, scores))
if len(scores) == 0:
raise FailedToFindAlignmentException('Synchronization failed; consider passing '
'--max-offset-seconds with a number larger than '
'{}'.format(self.max_offset_seconds))
(score, offset), subpipe = max(scores, key=lambda x: x[0][0])
return offset, subpipe
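# Illustrative sketch only (toy binary "speech" strings): FFTAligner reports how
# many samples the subtitle signal must be shifted to best match the reference.
def _example_fft_alignment():  # hypothetical helper, shown for clarity
    ref = '0000111100000000'   # reference speech starts at sample 4
    sub = '0011110000000000'   # subtitle speech starts at sample 2
    offset = FFTAligner().fit(ref, sub).transform(None)
    assert offset == 2         # delay the subtitles by 2 samples to line up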

View file

@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
SUBSYNC_RESOURCES_ENV_MAGIC = "ffsubsync_resources_xj48gjdkl340"
SAMPLE_RATE = 100
FRAMERATE_RATIOS = [24./23.976, 25./23.976, 25./24.]
DEFAULT_FRAME_RATE = 48000
DEFAULT_ENCODING = 'infer'
DEFAULT_MAX_SUBTITLE_SECONDS = 10
DEFAULT_START_SECONDS = 0
DEFAULT_SCALE_FACTOR = 1
DEFAULT_VAD = 'subs_then_webrtc'
DEFAULT_MAX_OFFSET_SECONDS = 600
SUBTITLE_EXTENSIONS = ('srt', 'ass', 'ssa')
GITHUB_DEV_USER = 'smacke'
PROJECT_NAME = 'FFsubsync'
PROJECT_LICENSE = 'MIT'
COPYRIGHT_YEAR = '2019'
GITHUB_REPO = 'ffsubsync'
DESCRIPTION = 'Synchronize subtitles with video.'
LONG_DESCRIPTION = 'Automatic and language-agnostic synchronization of subtitles with video.'
WEBSITE = 'https://github.com/{}/{}/'.format(GITHUB_DEV_USER, GITHUB_REPO)
DEV_WEBSITE = 'https://smacke.net/'
# Note: the absence of a trailing slash is important for this one...
API_RELEASE_URL = 'https://api.github.com/repos/{}/{}/releases/latest'.format(GITHUB_DEV_USER, GITHUB_REPO)
RELEASE_URL = 'https://github.com/{}/{}/releases/latest/'.format(GITHUB_DEV_USER, GITHUB_REPO)

265
libs/ffsubsync/ffsubsync.py Normal file
View file

@ -0,0 +1,265 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
from datetime import datetime
import logging
import os
import shutil
import sys
import numpy as np
from .sklearn_shim import Pipeline
from .aligners import FFTAligner, MaxScoreAligner, FailedToFindAlignmentException
from .constants import *
from .speech_transformers import (
VideoSpeechTransformer,
DeserializeSpeechTransformer,
make_subtitle_speech_pipeline
)
from .subtitle_parser import make_subtitle_parser
from .subtitle_transformers import SubtitleMerger, SubtitleShifter
from .version import __version__
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
def override(args, **kwargs):
args_dict = dict(args.__dict__)
args_dict.update(kwargs)
return args_dict
def run(args):
retval = 0
if args.vlc_mode:
logger.setLevel(logging.CRITICAL)
if args.make_test_case and not args.gui_mode: # this validation not necessary for gui mode
if args.srtin is None or args.srtout is None:
logger.error('need to specify input and output srt files for test cases')
return 1
if args.overwrite_input:
if args.srtin is None:
logger.error('need to specify input srt if --overwrite-input is specified since we cannot overwrite stdin')
return 1
if args.srtout is not None:
logger.error('overwrite input set but output file specified; refusing to run in case this was not intended')
return 1
args.srtout = args.srtin
if args.gui_mode and args.srtout is None:
args.srtout = '{}.synced.srt'.format(args.srtin[:-4])
ref_format = args.reference[-3:]
if args.merge_with_reference and ref_format not in SUBTITLE_EXTENSIONS:
logger.error('merging synced output with reference only valid '
'when reference composed of subtitles')
return 1
if args.make_test_case:
handler = logging.FileHandler('ffsubsync.log')
logger.addHandler(handler)
if ref_format in SUBTITLE_EXTENSIONS:
if args.vad is not None:
logger.warning('Vad specified, but reference was not a movie')
reference_pipe = make_subtitle_speech_pipeline(
fmt=ref_format,
**override(
args,
encoding=args.reference_encoding or DEFAULT_ENCODING
)
)
elif ref_format in ('npy', 'npz'):
if args.vad is not None:
logger.warning('Vad specified, but reference was not a movie')
reference_pipe = Pipeline([
('deserialize', DeserializeSpeechTransformer())
])
else:
vad = args.vad or DEFAULT_VAD
if args.reference_encoding is not None:
logger.warning('Reference srt encoding specified, but reference was a video file')
ref_stream = args.reference_stream
if ref_stream is not None and not ref_stream.startswith('0:'):
ref_stream = '0:' + ref_stream
reference_pipe = Pipeline([
('speech_extract', VideoSpeechTransformer(vad=vad,
sample_rate=SAMPLE_RATE,
frame_rate=args.frame_rate,
start_seconds=args.start_seconds,
ffmpeg_path=args.ffmpeg_path,
ref_stream=ref_stream,
vlc_mode=args.vlc_mode,
gui_mode=args.gui_mode))
])
if args.no_fix_framerate:
framerate_ratios = [1.]
else:
framerate_ratios = np.concatenate([
[1.], np.array(FRAMERATE_RATIOS), 1./np.array(FRAMERATE_RATIOS)
])
logger.info("extracting speech segments from reference '%s'...", args.reference)
reference_pipe.fit(args.reference)
logger.info('...done')
npy_savename = None
if args.make_test_case or args.serialize_speech:
logger.info('serializing speech...')
npy_savename = os.path.splitext(args.reference)[0] + '.npz'
np.savez_compressed(npy_savename, speech=reference_pipe.transform(args.reference))
logger.info('...done')
if args.srtin is None:
logger.info('unsynchronized subtitle file not specified; skipping synchronization')
return retval
parser = make_subtitle_parser(fmt=args.srtin[-3:], caching=True, **args.__dict__)
logger.info("extracting speech segments from subtitles '%s'...", args.srtin)
srt_pipes = [
make_subtitle_speech_pipeline(
**override(args, scale_factor=scale_factor, parser=parser)
).fit(args.srtin)
for scale_factor in framerate_ratios
]
logger.info('...done')
logger.info('computing alignments...')
max_offset_seconds = args.max_offset_seconds
try:
sync_was_successful = True
offset_samples, best_srt_pipe = MaxScoreAligner(
FFTAligner, SAMPLE_RATE, max_offset_seconds
).fit_transform(
reference_pipe.transform(args.reference),
srt_pipes,
)
logger.info('...done')
offset_seconds = offset_samples / float(SAMPLE_RATE)
scale_step = best_srt_pipe.named_steps['scale']
logger.info('offset seconds: %.3f', offset_seconds)
logger.info('framerate scale factor: %.3f', scale_step.scale_factor)
output_steps = [('shift', SubtitleShifter(offset_seconds))]
if args.merge_with_reference:
output_steps.append(
('merge',
SubtitleMerger(reference_pipe.named_steps['parse'].subs_))
)
output_pipe = Pipeline(output_steps)
out_subs = output_pipe.fit_transform(scale_step.subs_)
if args.output_encoding != 'same':
out_subs = out_subs.set_encoding(args.output_encoding)
logger.info('writing output to {}'.format(args.srtout or 'stdout'))
out_subs.write_file(args.srtout)
except FailedToFindAlignmentException as e:
sync_was_successful = False
logger.error(e)
if args.make_test_case:
if npy_savename is None:
raise ValueError('need non-null npy_savename')
tar_dir = '{}.{}'.format(
args.reference,
datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
)
logger.info('creating test archive {}.tar.gz...'.format(tar_dir))
os.mkdir(tar_dir)
try:
shutil.move('ffsubsync.log', tar_dir)
shutil.copy(args.srtin, tar_dir)
if sync_was_successful:
shutil.move(args.srtout, tar_dir)
if ref_format in SUBTITLE_EXTENSIONS:
shutil.copy(args.reference, tar_dir)
elif args.serialize_speech or args.reference == npy_savename:
shutil.copy(npy_savename, tar_dir)
else:
shutil.move(npy_savename, tar_dir)
supported_formats = set(list(zip(*shutil.get_archive_formats()))[0])
preferred_formats = ['gztar', 'bztar', 'xztar', 'zip', 'tar']
for archive_format in preferred_formats:
if archive_format in supported_formats:
shutil.make_archive(tar_dir, 'gztar', os.curdir, tar_dir)
break
else:
logger.error('failed to create test archive; no formats supported '
'(this should not happen)')
retval = 1
logger.info('...done')
finally:
shutil.rmtree(tar_dir)
return retval
def add_main_args_for_cli(parser):
parser.add_argument(
'reference',
help='Reference (video, subtitles, or a numpy array with VAD speech) to which to synchronize input subtitles.'
)
parser.add_argument('-i', '--srtin', help='Input subtitles file (default=stdin).')
parser.add_argument('-o', '--srtout', help='Output subtitles file (default=stdout).')
parser.add_argument('--merge-with-reference', '--merge', action='store_true',
help='Merge reference subtitles with synced output subtitles.')
parser.add_argument('--make-test-case', '--create-test-case', action='store_true',
help='If specified, serialize reference speech to a numpy array, '
'and create an archive with input/output subtitles '
'and serialized speech.')
def add_cli_only_args(parser):
parser.add_argument('-v', '--version', action='version',
version='%(prog)s {version}'.format(version=__version__))
parser.add_argument('--overwrite-input', action='store_true',
help='If specified, will overwrite the input srt instead of writing the output to a new file.')
parser.add_argument('--encoding', default=DEFAULT_ENCODING,
help='What encoding to use for reading input subtitles '
'(default=%s).' % DEFAULT_ENCODING)
parser.add_argument('--max-subtitle-seconds', type=float, default=DEFAULT_MAX_SUBTITLE_SECONDS,
help='Maximum duration for a subtitle to appear on-screen '
'(default=%.3f seconds).' % DEFAULT_MAX_SUBTITLE_SECONDS)
parser.add_argument('--start-seconds', type=int, default=DEFAULT_START_SECONDS,
help='Start time for processing '
'(default=%d seconds).' % DEFAULT_START_SECONDS)
parser.add_argument('--max-offset-seconds', type=int, default=DEFAULT_MAX_OFFSET_SECONDS,
help='The max allowed offset seconds for any subtitle segment '
'(default=%d seconds).' % DEFAULT_MAX_OFFSET_SECONDS)
parser.add_argument('--frame-rate', type=int, default=DEFAULT_FRAME_RATE,
help='Frame rate for audio extraction (default=%d).' % DEFAULT_FRAME_RATE)
parser.add_argument('--output-encoding', default='utf-8',
help='What encoding to use for writing output subtitles '
'(default=utf-8). Can indicate "same" to use same '
'encoding as that of the input.')
parser.add_argument('--reference-encoding',
help='What encoding to use for reading / writing reference subtitles '
'(if applicable, default=infer).')
parser.add_argument('--vad', choices=['subs_then_webrtc', 'webrtc', 'subs_then_auditok', 'auditok'],
default=None,
help='Which voice activity detector to use for speech extraction '
'(if using video / audio as a reference, default={}).'.format(DEFAULT_VAD))
parser.add_argument('--no-fix-framerate', action='store_true',
help='If specified, subsync will not attempt to correct a framerate '
'mismatch between reference and subtitles.')
parser.add_argument('--serialize-speech', action='store_true',
help='If specified, serialize reference speech to a numpy array.')
parser.add_argument(
'--reference-stream', '--refstream', '--reference-track', '--reftrack',
default=None,
help='Which stream/track in the video file to use as reference, '
'formatted according to ffmpeg conventions. For example, s:0 '
'uses the first subtitle track; a:3 would use the fourth audio track.'
)
parser.add_argument(
'--ffmpeg-path', '--ffmpegpath', default=None,
help='Where to look for ffmpeg and ffprobe. Uses the system PATH by default.'
)
parser.add_argument('--vlc-mode', action='store_true', help=argparse.SUPPRESS)
parser.add_argument('--gui-mode', action='store_true', help=argparse.SUPPRESS)
def make_parser():
parser = argparse.ArgumentParser(description='Synchronize subtitles with video.')
add_main_args_for_cli(parser)
add_cli_only_args(parser)
return parser
def main():
parser = make_parser()
args = parser.parse_args()
return run(args)
if __name__ == "__main__":
sys.exit(main())
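# A minimal usage sketch (illustrative; the file names below are assumptions):
# with this module exposed through a console entry point named `ffsubsync`,
# a typical invocation synchronizes an out-of-sync SRT against a video:
#
#     ffsubsync reference.mkv -i unsynced.srt -o synced.srt
#
# The same thing can be done programmatically via the parser defined above:
#
#     args = make_parser().parse_args(['reference.mkv', '-i', 'unsynced.srt', '-o', 'synced.srt'])
#     run(args)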

View file

@ -0,0 +1,107 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
import os
import sys
from gooey import Gooey, GooeyParser
from .constants import (
RELEASE_URL,
WEBSITE,
DEV_WEBSITE,
DESCRIPTION,
LONG_DESCRIPTION,
PROJECT_NAME,
PROJECT_LICENSE,
COPYRIGHT_YEAR,
SUBSYNC_RESOURCES_ENV_MAGIC,
)
from .ffsubsync import run, add_cli_only_args
from .version import __version__, update_available
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
_menu = [
{
'name': 'File',
'items': [
{
'type': 'AboutDialog',
'menuTitle': 'About',
'name': PROJECT_NAME,
'description': LONG_DESCRIPTION,
'version': __version__,
'copyright': COPYRIGHT_YEAR,
'website': WEBSITE,
'developer': DEV_WEBSITE,
'license': PROJECT_LICENSE,
},
{
'type': 'Link',
'menuTitle': 'Download latest release',
'url': RELEASE_URL,
}
]
}
]
# set the env magic so that we look for resources in the right place
if SUBSYNC_RESOURCES_ENV_MAGIC not in os.environ:
os.environ[SUBSYNC_RESOURCES_ENV_MAGIC] = getattr(sys, '_MEIPASS', '')
@Gooey(
program_name=PROJECT_NAME,
image_dir=os.path.join(os.environ[SUBSYNC_RESOURCES_ENV_MAGIC], 'img'),
menu=_menu,
tabbed_groups=True,
progress_regex=r"(\d+)%",
hide_progress_msg=True
)
def make_parser():
description = DESCRIPTION
if update_available():
description += '\nUpdate available! Please go to "File" -> "Download latest release" to update FFsubsync.'
parser = GooeyParser(description=description)
main_group = parser.add_argument_group('Basic')
main_group.add_argument(
'reference',
help='Reference (video or subtitles file) to which to synchronize input subtitles.',
widget='FileChooser'
)
main_group.add_argument('srtin', help='Input subtitles file', widget='FileChooser')
main_group.add_argument('-o', '--srtout',
help='Output subtitles file (default=${srtin}.synced.srt).',
widget='FileSaver')
advanced_group = parser.add_argument_group('Advanced')
# TODO: these are shared between gui and cli; don't duplicate this code
advanced_group.add_argument('--merge-with-reference', '--merge', action='store_true',
help='Merge reference subtitles with synced output subtitles.')
advanced_group.add_argument('--make-test-case', '--create-test-case', action='store_true',
help='If specified, create a test archive a few KiB in size '
'to send to the developer as a debugging aid.')
advanced_group.add_argument(
'--reference-stream', '--refstream', '--reference-track', '--reftrack', default=None,
help='Which stream/track in the video file to use as reference, '
'formatted according to ffmpeg conventions. For example, s:0 '
'uses the first subtitle track; a:3 would use the fourth audio track.'
)
return parser
def main():
parser = make_parser()
_ = parser.parse_args() # Fool Gooey into presenting the simpler menu
add_cli_only_args(parser)
args = parser.parse_args()
args.gui_mode = True
return run(args)
if __name__ == "__main__":
sys.exit(main())

View file

@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
import six
import sys
class open_file(object):
"""
Context manager that opens a filename and closes it on exit, but does
nothing for file-like objects.
"""
def __init__(self, filename, *args, **kwargs):
self.closing = kwargs.pop('closing', False)
if filename is None:
stream = sys.stdout if 'w' in args else sys.stdin
if six.PY3:
self.closeable = open(stream.fileno(), *args, **kwargs)
self.fh = self.closeable.buffer
else:
self.closeable = stream
self.fh = self.closeable
elif isinstance(filename, six.string_types):
self.fh = open(filename, *args, **kwargs)
self.closeable = self.fh
self.closing = True
else:
self.fh = filename
def __enter__(self):
return self.fh
def __exit__(self, exc_type, exc_val, exc_tb):
if self.closing:
self.closeable.close()
return False
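# Illustrative usage sketch (the file name and the io.BytesIO input are
# assumptions): a real path is opened and closed by the context manager, while
# already-open file-like objects are passed through untouched; None falls back
# to the process's stdin/stdout so callers can treat all three uniformly.
#
#     with open_file('example.srt', 'rb') as f:      # opened here, closed on exit
#         data = f.read()
#     import io
#     with open_file(io.BytesIO(data), 'rb') as f:   # file-like objects pass through
#         data = f.read()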

View file

@ -0,0 +1,140 @@
# -*- coding: utf-8 -*-
import copy
from datetime import timedelta
import logging
import pysubs2
import srt
import six
import sys
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SubsMixin(object):
def __init__(self, subs=None):
self.subs_ = subs
def set_encoding(self, encoding):
self.subs_.set_encoding(encoding)
return self
class GenericSubtitle(object):
def __init__(self, start, end, inner):
self.start = start
self.end = end
self.inner = inner
def __eq__(self, other):
eq = True
eq = eq and self.start == other.start
eq = eq and self.end == other.end
eq = eq and self.inner == other.inner
return eq
def resolve_inner_timestamps(self):
ret = copy.deepcopy(self.inner)
if isinstance(self.inner, srt.Subtitle):
ret.start = self.start
ret.end = self.end
elif isinstance(self.inner, pysubs2.SSAEvent):
ret.start = pysubs2.make_time(s=self.start.total_seconds())
ret.end = pysubs2.make_time(s=self.end.total_seconds())
else:
raise NotImplementedError('unsupported subtitle type: %s' % type(self.inner))
return ret
def merge_with(self, other):
assert isinstance(self.inner, type(other.inner))
inner_merged = copy.deepcopy(self.inner)
if isinstance(self.inner, srt.Subtitle):
inner_merged.content = u'{}\n{}'.format(inner_merged.content, other.inner.content)
return self.__class__(
self.start,
self.end,
inner_merged
)
else:
raise NotImplementedError('unsupported subtitle type: %s' % type(self.inner))
@classmethod
def wrap_inner_subtitle(cls, sub):
if isinstance(sub, srt.Subtitle):
return cls(sub.start, sub.end, sub)
elif isinstance(sub, pysubs2.SSAEvent):
return cls(
timedelta(milliseconds=sub.start),
timedelta(milliseconds=sub.end),
sub
)
else:
raise NotImplementedError('unsupported subtitle type: %s' % type(sub))
class GenericSubtitlesFile(object):
def __init__(self, subs, *args, **kwargs):
sub_format = kwargs.pop('sub_format', None)
if sub_format is None:
raise ValueError('format must be specified')
encoding = kwargs.pop('encoding', None)
if encoding is None:
raise ValueError('encoding must be specified')
self.subs_ = subs
self._sub_format = sub_format
self._encoding = encoding
def set_encoding(self, encoding):
if encoding != 'same':
self._encoding = encoding
return self
def __len__(self):
return len(self.subs_)
def __getitem__(self, item):
return self.subs_[item]
@property
def sub_format(self):
return self._sub_format
@property
def encoding(self):
return self._encoding
def gen_raw_resolved_subs(self):
for sub in self.subs_:
yield sub.resolve_inner_timestamps()
def offset(self, td):
offset_subs = []
for sub in self.subs_:
offset_subs.append(
GenericSubtitle(sub.start + td, sub.end + td, sub.inner)
)
return GenericSubtitlesFile(
offset_subs,
sub_format=self.sub_format,
encoding=self.encoding
)
def write_file(self, fname):
subs = list(self.gen_raw_resolved_subs())
if self.sub_format == 'srt':
to_write = srt.compose(subs)
elif self.sub_format in ('ssa', 'ass'):
ssaf = pysubs2.SSAFile()
ssaf.events = subs
to_write = ssaf.to_string(self.sub_format)
else:
raise NotImplementedError('unsupported format: %s' % self.sub_format)
to_write = to_write.encode(self.encoding)
if six.PY3:
with open(fname or sys.stdout.fileno(), 'wb') as f:
f.write(to_write)
else:
with (fname and open(fname, 'wb')) or sys.stdout as f:
f.write(to_write)
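# Illustrative sketch (the input file name is an assumption): wrapping parsed
# srt entries in GenericSubtitle and shifting the whole file by two seconds.
#
#     import srt
#     from datetime import timedelta
#     with open('example.srt') as f:
#         wrapped = [GenericSubtitle.wrap_inner_subtitle(s) for s in srt.parse(f.read())]
#     subs = GenericSubtitlesFile(wrapped, sub_format='srt', encoding='utf-8')
#     subs.offset(timedelta(seconds=2)).write_file('example.shifted.srt')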

View file

@ -0,0 +1,374 @@
# -*- coding: utf-8 -*-
"""
This module borrows and adapts `Pipeline` from `sklearn.pipeline` and
`TransformerMixin` from `sklearn.base` in the scikit-learn framework
(commit hash d205638475ca542dc46862652e3bb0be663a8eac, to be precise).
Both are BSD licensed and allow for this sort of thing; attribution
is given as a comment above each class.
"""
from collections import defaultdict
from itertools import islice
# Author: Gael Varoquaux <gael.varoquaux@normalesup.org>
# License: BSD 3 clause
class TransformerMixin(object):
"""Mixin class for all transformers."""
def fit_transform(self, X, y=None, **fit_params):
"""
Fit to data, then transform it.
Fits transformer to X and y with optional parameters fit_params
and returns a transformed version of X.
Parameters
----------
X : ndarray of shape (n_samples, n_features)
Training set.
y : ndarray of shape (n_samples,), default=None
Target values.
**fit_params : dict
Additional fit parameters.
Returns
-------
X_new : ndarray array of shape (n_samples, n_features_new)
Transformed array.
"""
# non-optimized default implementation; override when a better
# method is possible for a given clustering algorithm
if y is None:
# fit method of arity 1 (unsupervised transformation)
return self.fit(X, **fit_params).transform(X)
else:
# fit method of arity 2 (supervised transformation)
return self.fit(X, y, **fit_params).transform(X)
# Author: Edouard Duchesnay
# Gael Varoquaux
# Virgile Fritsch
# Alexandre Gramfort
# Lars Buitinck
# License: BSD
class Pipeline(object):
def __init__(self, steps, verbose=False):
self.steps = steps
self.verbose = verbose
self._validate_steps()
def _validate_steps(self):
names, estimators = zip(*self.steps)
# validate estimators
transformers = estimators[:-1]
estimator = estimators[-1]
for t in transformers:
if t is None or t == 'passthrough':
continue
if (not (hasattr(t, "fit") or hasattr(t, "fit_transform")) or not
hasattr(t, "transform")):
raise TypeError("All intermediate steps should be "
"transformers and implement fit and transform "
"or be the string 'passthrough' "
"'%s' (type %s) doesn't" % (t, type(t)))
# We allow last estimator to be None as an identity transformation
if (estimator is not None and estimator != 'passthrough'
and not hasattr(estimator, "fit")):
raise TypeError(
"Last step of Pipeline should implement fit "
"or be the string 'passthrough'. "
"'%s' (type %s) doesn't" % (estimator, type(estimator)))
def _iter(self, with_final=True, filter_passthrough=True):
"""
Generate (idx, (name, trans)) tuples from self.steps
When filter_passthrough is True, 'passthrough' and None transformers
are filtered out.
"""
stop = len(self.steps)
if not with_final:
stop -= 1
for idx, (name, trans) in enumerate(islice(self.steps, 0, stop)):
if not filter_passthrough:
yield idx, name, trans
elif trans is not None and trans != 'passthrough':
yield idx, name, trans
def __len__(self):
"""
Returns the length of the Pipeline
"""
return len(self.steps)
def __getitem__(self, ind):
"""Returns a sub-pipeline or a single esimtator in the pipeline
Indexing with an integer will return an estimator; using a slice
returns another Pipeline instance which copies a slice of this
Pipeline. This copy is shallow: modifying (or fitting) estimators in
the sub-pipeline will affect the larger pipeline and vice-versa.
However, replacing a value in `step` will not affect a copy.
"""
if isinstance(ind, slice):
if ind.step not in (1, None):
raise ValueError('Pipeline slicing only supports a step of 1')
return self.__class__(self.steps[ind])
try:
name, est = self.steps[ind]
except TypeError:
# Not an int, try get step by name
return self.named_steps[ind]
return est
@property
def _estimator_type(self):
return self.steps[-1][1]._estimator_type
@property
def named_steps(self):
return dict(self.steps)
@property
def _final_estimator(self):
estimator = self.steps[-1][1]
return 'passthrough' if estimator is None else estimator
def _log_message(self, step_idx):
if not self.verbose:
return None
name, step = self.steps[step_idx]
return '(step %d of %d) Processing %s' % (step_idx + 1,
len(self.steps),
name)
# Estimator interface
def _fit(self, X, y=None, **fit_params):
# shallow copy of steps - this should really be steps_
self.steps = list(self.steps)
self._validate_steps()
fit_params_steps = {name: {} for name, step in self.steps
if step is not None}
for pname, pval in fit_params.items():
if '__' not in pname:
raise ValueError(
"Pipeline.fit does not accept the {} parameter. "
"You can pass parameters to specific steps of your "
"pipeline using the stepname__parameter format, e.g. "
"`Pipeline.fit(X, y, logisticregression__sample_weight"
"=sample_weight)`.".format(pname))
step, param = pname.split('__', 1)
fit_params_steps[step][param] = pval
for (step_idx,
name,
transformer) in self._iter(with_final=False,
filter_passthrough=False):
if transformer is None or transformer == 'passthrough':
continue
# Fit or load from cache the current transformer
X, fitted_transformer = _fit_transform_one(
transformer, X, y, None,
**fit_params_steps[name])
# Replace the transformer of the step with the fitted
# transformer. This is necessary when loading the transformer
# from the cache.
self.steps[step_idx] = (name, fitted_transformer)
if self._final_estimator == 'passthrough':
return X, {}
return X, fit_params_steps[self.steps[-1][0]]
def fit(self, X, y=None, **fit_params):
"""Fit the model
Fit all the transforms one after the other and transform the
data, then fit the transformed data using the final estimator.
Parameters
----------
X : iterable
Training data. Must fulfill input requirements of first step of the
pipeline.
y : iterable, default=None
Training targets. Must fulfill label requirements for all steps of
the pipeline.
**fit_params : dict of string -> object
Parameters passed to the ``fit`` method of each step, where
each parameter name is prefixed such that parameter ``p`` for step
``s`` has key ``s__p``.
Returns
-------
self : Pipeline
This estimator
"""
Xt, fit_params = self._fit(X, y, **fit_params)
if self._final_estimator != 'passthrough':
self._final_estimator.fit(Xt, y, **fit_params)
return self
def fit_transform(self, X, y=None, **fit_params):
"""Fit the model and transform with the final estimator
Fits all the transforms one after the other and transforms the
data, then uses fit_transform on transformed data with the final
estimator.
Parameters
----------
X : iterable
Training data. Must fulfill input requirements of first step of the
pipeline.
y : iterable, default=None
Training targets. Must fulfill label requirements for all steps of
the pipeline.
**fit_params : dict of string -> object
Parameters passed to the ``fit`` method of each step, where
each parameter name is prefixed such that parameter ``p`` for step
``s`` has key ``s__p``.
Returns
-------
Xt : array-like of shape (n_samples, n_transformed_features)
Transformed samples
"""
last_step = self._final_estimator
Xt, fit_params = self._fit(X, y, **fit_params)
if last_step == 'passthrough':
return Xt
if hasattr(last_step, 'fit_transform'):
return last_step.fit_transform(Xt, y, **fit_params)
else:
return last_step.fit(Xt, y, **fit_params).transform(Xt)
@property
def transform(self):
"""Apply transforms, and transform with the final estimator
This also works where final estimator is ``None``: all prior
transformations are applied.
Parameters
----------
X : iterable
Data to transform. Must fulfill input requirements of first step
of the pipeline.
Returns
-------
Xt : array-like of shape (n_samples, n_transformed_features)
"""
# _final_estimator is None or has transform, otherwise attribute error
# XXX: Handling the None case means we can't use if_delegate_has_method
if self._final_estimator != 'passthrough':
self._final_estimator.transform
return self._transform
def _transform(self, X):
Xt = X
for _, _, transform in self._iter():
Xt = transform.transform(Xt)
return Xt
@property
def classes_(self):
return self.steps[-1][-1].classes_
@property
def _pairwise(self):
# check if first estimator expects pairwise input
return getattr(self.steps[0][1], '_pairwise', False)
@property
def n_features_in_(self):
# delegate to first step (which will call _check_is_fitted)
return self.steps[0][1].n_features_in_
def _name_estimators(estimators):
"""Generate names for estimators."""
names = [
estimator
if isinstance(estimator, str) else type(estimator).__name__.lower()
for estimator in estimators
]
namecount = defaultdict(int)
for est, name in zip(estimators, names):
namecount[name] += 1
for k, v in list(namecount.items()):
if v == 1:
del namecount[k]
for i in reversed(range(len(estimators))):
name = names[i]
if name in namecount:
names[i] += "-%d" % namecount[name]
namecount[name] -= 1
return list(zip(names, estimators))
def make_pipeline(*steps, **kwargs):
"""Construct a Pipeline from the given estimators.
This is a shorthand for the Pipeline constructor; it does not require, and
does not permit, naming the estimators. Instead, their names will be set
to the lowercase of their types automatically.
Parameters
----------
*steps : list of estimators.
verbose : bool, default=False
If True, the time elapsed while fitting each step will be printed as it
is completed.
Returns
-------
p : Pipeline
"""
verbose = kwargs.pop('verbose', False)
if kwargs:
raise TypeError('Unknown keyword arguments: "{}"'
.format(list(kwargs.keys())[0]))
return Pipeline(_name_estimators(steps), verbose=verbose)
def _transform_one(transformer, X, y, weight, **fit_params):
res = transformer.transform(X)
# if we have a weight for this transformer, multiply output
if weight is None:
return res
return res * weight
def _fit_transform_one(transformer,
X,
y,
weight,
**fit_params):
"""
Fits ``transformer`` to ``X`` and ``y``. The transformed result is returned
with the fitted transformer. If ``weight`` is not ``None``, the result will
be multiplied by ``weight``.
"""
if hasattr(transformer, 'fit_transform'):
res = transformer.fit_transform(X, y, **fit_params)
else:
res = transformer.fit(X, y, **fit_params).transform(X)
if weight is None:
return res, transformer
return res * weight, transformer
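# A minimal, self-contained sketch of the shim's API (the _AddOne transformer
# and the sample data are assumptions, not part of ffsubsync): it mirrors how
# the rest of the package chains its parse/scale/speech-extraction steps.
if __name__ == '__main__':
    class _AddOne(TransformerMixin):
        def fit(self, X, *_):
            return self
        def transform(self, X):
            return [x + 1 for x in X]

    # two chained steps: [1, 2, 3] -> [2, 3, 4] -> [3, 4, 5]
    assert make_pipeline(_AddOne(), _AddOne()).fit_transform([1, 2, 3]) == [3, 4, 5]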

View file

@ -0,0 +1,368 @@
# -*- coding: utf-8 -*-
from contextlib import contextmanager
import logging
import io
import os
import platform
import subprocess
import sys
from datetime import timedelta
import ffmpeg
import numpy as np
from .sklearn_shim import TransformerMixin
from .sklearn_shim import Pipeline
import tqdm
from .constants import *
from .subtitle_parser import make_subtitle_parser
from .subtitle_transformers import SubtitleScaler
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ref: https://github.com/pyinstaller/pyinstaller/wiki/Recipe-subprocess
# Create a set of arguments which make a ``subprocess.Popen`` (and
# variants) call work with or without Pyinstaller, ``--noconsole`` or
# not, on Windows and Linux. Typical use::
#
# subprocess.call(['program_to_run', 'arg_1'], **subprocess_args())
#
# When calling ``check_output``::
#
# subprocess.check_output(['program_to_run', 'arg_1'],
# **subprocess_args(False))
def _subprocess_args(include_stdout=True):
# The following is true only on Windows.
if hasattr(subprocess, 'STARTUPINFO'):
# On Windows, subprocess calls will pop up a command window by default
# when run from Pyinstaller with the ``--noconsole`` option. Avoid this
# distraction.
si = subprocess.STARTUPINFO()
si.dwFlags |= subprocess.STARTF_USESHOWWINDOW
# Windows doesn't search the path by default. Pass it an environment so
# it will.
env = os.environ
else:
si = None
env = None
# ``subprocess.check_output`` doesn't allow specifying ``stdout``::
#
# Traceback (most recent call last):
# File "test_subprocess.py", line 58, in <module>
# **subprocess_args(stdout=None))
# File "C:\Python27\lib\subprocess.py", line 567, in check_output
# raise ValueError('stdout argument not allowed, it will be overridden.')
# ValueError: stdout argument not allowed, it will be overridden.
#
# So, add it only if it's needed.
if include_stdout:
ret = {'stdout': subprocess.PIPE}
else:
ret = {}
# On Windows, running this from the binary produced by Pyinstaller
# with the ``--noconsole`` option requires redirecting everything
# (stdin, stdout, stderr) to avoid an OSError exception
# "[Error 6] the handle is invalid."
ret.update({'stdin': subprocess.PIPE,
'stderr': subprocess.PIPE,
'startupinfo': si,
'env': env})
return ret
def _ffmpeg_bin_path(bin_name, gui_mode, ffmpeg_resources_path=None):
if platform.system() == 'Windows':
bin_name = '{}.exe'.format(bin_name)
if ffmpeg_resources_path is not None:
return os.path.join(ffmpeg_resources_path, bin_name)
try:
resource_path = os.environ[SUBSYNC_RESOURCES_ENV_MAGIC]
if len(resource_path) > 0:
return os.path.join(resource_path, 'ffmpeg-bin', bin_name)
except KeyError as e:
if gui_mode:
logger.info("Couldn't find resource path; falling back to searching system path")
return bin_name
def make_subtitle_speech_pipeline(
fmt='srt',
encoding=DEFAULT_ENCODING,
caching=False,
max_subtitle_seconds=DEFAULT_MAX_SUBTITLE_SECONDS,
start_seconds=DEFAULT_START_SECONDS,
scale_factor=DEFAULT_SCALE_FACTOR,
parser=None,
**kwargs
):
if parser is None:
parser = make_subtitle_parser(
fmt,
encoding=encoding,
caching=caching,
max_subtitle_seconds=max_subtitle_seconds,
start_seconds=start_seconds
)
assert parser.encoding == encoding
assert parser.max_subtitle_seconds == max_subtitle_seconds
assert parser.start_seconds == start_seconds
return Pipeline([
('parse', parser),
('scale', SubtitleScaler(scale_factor)),
('speech_extract', SubtitleSpeechTransformer(
sample_rate=SAMPLE_RATE,
start_seconds=start_seconds,
framerate_ratio=scale_factor,
))
])
def _make_auditok_detector(sample_rate, frame_rate):
try:
from auditok import \
BufferAudioSource, ADSFactory, AudioEnergyValidator, StreamTokenizer
except ImportError as e:
logger.error("""Error: auditok not installed!
Consider installing it with `pip install auditok`. Note that auditok
is GPLv3 licensed, which means that successfully importing it at
runtime creates a derivative work that is GPLv3 licensed. For personal
use this is fine, but note that any commercial use that relies on
auditok must be open source as per the GPLv3!*
*Not legal advice. Consult with a lawyer.
""")
raise e
bytes_per_frame = 2
frames_per_window = frame_rate // sample_rate
validator = AudioEnergyValidator(
sample_width=bytes_per_frame, energy_threshold=50)
tokenizer = StreamTokenizer(
validator=validator, min_length=0.2*sample_rate,
max_length=int(5*sample_rate),
max_continuous_silence=0.25*sample_rate)
def _detect(asegment):
asource = BufferAudioSource(data_buffer=asegment,
sampling_rate=frame_rate,
sample_width=bytes_per_frame,
channels=1)
ads = ADSFactory.ads(audio_source=asource, block_dur=1./sample_rate)
ads.open()
tokens = tokenizer.tokenize(ads)
length = (len(asegment)//bytes_per_frame
+ frames_per_window - 1)//frames_per_window
media_bstring = np.zeros(length+1, dtype=int)
for token in tokens:
media_bstring[token[1]] += 1
media_bstring[token[2]+1] -= 1
return (np.cumsum(media_bstring)[:-1] > 0).astype(float)
return _detect
def _make_webrtcvad_detector(sample_rate, frame_rate):
import webrtcvad
vad = webrtcvad.Vad()
vad.set_mode(3) # set non-speech pruning aggressiveness from 0 to 3
window_duration = 1. / sample_rate # duration in seconds
frames_per_window = int(window_duration * frame_rate + 0.5)
bytes_per_frame = 2
def _detect(asegment):
media_bstring = []
failures = 0
for start in range(0, len(asegment) // bytes_per_frame,
frames_per_window):
stop = min(start + frames_per_window,
len(asegment) // bytes_per_frame)
try:
is_speech = vad.is_speech(
asegment[start * bytes_per_frame: stop * bytes_per_frame],
sample_rate=frame_rate)
except:
is_speech = False
failures += 1
# webrtcvad has low recall on mode 3, so treat non-speech as "not sure"
media_bstring.append(1. if is_speech else 0.5)
return np.array(media_bstring)
return _detect
class VideoSpeechTransformer(TransformerMixin):
def __init__(self, vad, sample_rate, frame_rate, start_seconds=0, ffmpeg_path=None, ref_stream=None, vlc_mode=False, gui_mode=False):
self.vad = vad
self.sample_rate = sample_rate
self.frame_rate = frame_rate
self.start_seconds = start_seconds
self.ffmpeg_path = ffmpeg_path
self.ref_stream = ref_stream
self.vlc_mode = vlc_mode
self.gui_mode = gui_mode
self.video_speech_results_ = None
def try_fit_using_embedded_subs(self, fname):
embedded_subs = []
embedded_subs_times = []
if self.ref_stream is None:
# check first 5; should cover 99% of movies
streams_to_try = map('0:s:{}'.format, range(5))
else:
streams_to_try = [self.ref_stream]
for stream in streams_to_try:
ffmpeg_args = [_ffmpeg_bin_path('ffmpeg', self.gui_mode, ffmpeg_resources_path=self.ffmpeg_path)]
ffmpeg_args.extend([
'-loglevel', 'fatal',
'-nostdin',
'-i', fname,
'-map', '{}'.format(stream),
'-f', 'srt',
'-'
])
process = subprocess.Popen(ffmpeg_args, **_subprocess_args(include_stdout=True))
output = io.BytesIO(process.communicate()[0])
if process.returncode != 0:
break
pipe = make_subtitle_speech_pipeline(start_seconds=self.start_seconds).fit(output)
speech_step = pipe.steps[-1][1]
embedded_subs.append(speech_step.subtitle_speech_results_)
embedded_subs_times.append(speech_step.max_time_)
if len(embedded_subs) == 0:
raise ValueError('Video file appears to lack subtitle stream')
# use longest set of embedded subs
self.video_speech_results_ = embedded_subs[int(np.argmax(embedded_subs_times))]
def fit(self, fname, *_):
if 'subs' in self.vad and (self.ref_stream is None or self.ref_stream.startswith('0:s:')):
try:
logger.info('Checking video for subtitles stream...')
self.try_fit_using_embedded_subs(fname)
logger.info('...success!')
return self
except Exception as e:
logger.info(e)
try:
total_duration = float(ffmpeg.probe(
fname, cmd=_ffmpeg_bin_path('ffprobe', self.gui_mode, ffmpeg_resources_path=self.ffmpeg_path)
)['format']['duration']) - self.start_seconds
except Exception as e:
logger.warning(e)
total_duration = None
if 'webrtc' in self.vad:
detector = _make_webrtcvad_detector(self.sample_rate, self.frame_rate)
elif 'auditok' in self.vad:
detector = _make_auditok_detector(self.sample_rate, self.frame_rate)
else:
raise ValueError('unknown vad: %s' % self.vad)
media_bstring = []
ffmpeg_args = [_ffmpeg_bin_path('ffmpeg', self.gui_mode, ffmpeg_resources_path=self.ffmpeg_path)]
if self.start_seconds > 0:
ffmpeg_args.extend([
'-ss', str(timedelta(seconds=self.start_seconds)),
])
ffmpeg_args.extend([
'-loglevel', 'fatal',
'-nostdin',
'-i', fname
])
if self.ref_stream is not None and self.ref_stream.startswith('0:a:'):
ffmpeg_args.extend(['-map', self.ref_stream])
ffmpeg_args.extend([
'-f', 's16le',
'-ac', '1',
'-acodec', 'pcm_s16le',
'-ar', str(self.frame_rate),
'-'
])
process = subprocess.Popen(ffmpeg_args, **_subprocess_args(include_stdout=True))
bytes_per_frame = 2
frames_per_window = bytes_per_frame * self.frame_rate // self.sample_rate
windows_per_buffer = 10000
simple_progress = 0.
@contextmanager
def redirect_stderr(enter_result=None):
yield enter_result
tqdm_extra_args = {}
should_print_redirected_stderr = self.gui_mode
if self.gui_mode:
try:
from contextlib import redirect_stderr
tqdm_extra_args['file'] = sys.stdout
except ImportError:
should_print_redirected_stderr = False
pbar_output = io.StringIO()
with redirect_stderr(pbar_output):
with tqdm.tqdm(total=total_duration, disable=self.vlc_mode, **tqdm_extra_args) as pbar:
while True:
in_bytes = process.stdout.read(frames_per_window * windows_per_buffer)
if not in_bytes:
break
newstuff = len(in_bytes) / float(bytes_per_frame) / self.frame_rate
simple_progress += newstuff
pbar.update(newstuff)
if self.vlc_mode and total_duration is not None:
print("%d" % int(simple_progress * 100. / total_duration))
sys.stdout.flush()
if should_print_redirected_stderr:
assert self.gui_mode
# no need to flush since we pass -u to do unbuffered output for gui mode
print(pbar_output.read())
in_bytes = np.frombuffer(in_bytes, np.uint8)
media_bstring.append(detector(in_bytes))
if len(media_bstring) == 0:
raise ValueError(
'Unable to detect speech. Perhaps try specifying a different stream / track, or a different vad.'
)
self.video_speech_results_ = np.concatenate(media_bstring)
return self
def transform(self, *_):
return self.video_speech_results_
class SubtitleSpeechTransformer(TransformerMixin):
def __init__(self, sample_rate, start_seconds=0, framerate_ratio=1.):
self.sample_rate = sample_rate
self.start_seconds = start_seconds
self.framerate_ratio = framerate_ratio
self.subtitle_speech_results_ = None
self.max_time_ = None
def fit(self, subs, *_):
max_time = 0
for sub in subs:
max_time = max(max_time, sub.end.total_seconds())
self.max_time_ = max_time - self.start_seconds
samples = np.zeros(int(max_time * self.sample_rate) + 2, dtype=float)
for sub in subs:
start = int(round((sub.start.total_seconds() - self.start_seconds) * self.sample_rate))
duration = sub.end.total_seconds() - sub.start.total_seconds()
end = start + int(round(duration * self.sample_rate))
samples[start:end] = min(1. / self.framerate_ratio, 1.)
self.subtitle_speech_results_ = samples
return self
def transform(self, *_):
return self.subtitle_speech_results_
class DeserializeSpeechTransformer(TransformerMixin):
def __init__(self):
self.deserialized_speech_results_ = None
def fit(self, fname, *_):
speech = np.load(fname)
if hasattr(speech, 'files'):
if 'speech' in speech.files:
speech = speech['speech']
else:
raise ValueError('could not find "speech" array in '
'serialized file; only contains: %s' % speech.files)
self.deserialized_speech_results_ = speech
return self
def transform(self, *_):
return self.deserialized_speech_results_
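# Illustrative usage sketch (assumes a local 'example.srt'; constants come from
# .constants): the subtitle pipeline parses, scales, and converts subtitles into
# a 1-D speech-occupancy array sampled at SAMPLE_RATE, which the aligner can
# then compare against VideoSpeechTransformer output.
#
#     pipe = make_subtitle_speech_pipeline(fmt='srt')
#     subtitle_speech = pipe.fit_transform('example.srt')   # numpy array of per-sample occupancy values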

View file

@ -0,0 +1,27 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
import sys
from .sklearn_shim import Pipeline
from .subtitle_parser import GenericSubtitleParser
from .subtitle_transformers import SubtitleShifter
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def main():
td = float(sys.argv[3])
pipe = Pipeline([
('parse', GenericSubtitleParser()),
('offset', SubtitleShifter(td)),
])
pipe.fit_transform(sys.argv[1])
pipe.steps[-1][1].write_file(sys.argv[2])
return 0
if __name__ == "__main__":
sys.exit(main())

View file

@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-
from datetime import timedelta
import logging
import chardet
import pysubs2
from .sklearn_shim import TransformerMixin
import srt
from .constants import *
from .file_utils import open_file
from .generic_subtitles import GenericSubtitle, GenericSubtitlesFile, SubsMixin
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def make_subtitle_parser(
fmt,
encoding=DEFAULT_ENCODING,
caching=False,
max_subtitle_seconds=DEFAULT_MAX_SUBTITLE_SECONDS,
start_seconds=DEFAULT_START_SECONDS,
**kwargs
):
return GenericSubtitleParser(
fmt=fmt,
encoding=encoding,
caching=caching,
max_subtitle_seconds=max_subtitle_seconds,
start_seconds=start_seconds
)
def _preprocess_subs(subs, max_subtitle_seconds=None, start_seconds=0, tolerant=True):
subs_list = []
start_time = timedelta(seconds=start_seconds)
max_duration = timedelta(days=1)
if max_subtitle_seconds is not None:
max_duration = timedelta(seconds=max_subtitle_seconds)
subs = iter(subs)
while True:
try:
next_sub = GenericSubtitle.wrap_inner_subtitle(next(subs))
if next_sub.start < start_time:
continue
next_sub.end = min(next_sub.end, next_sub.start + max_duration)
subs_list.append(next_sub)
# We don't catch SRTParseError here b/c that is typically raised when we
# are trying to parse with the wrong encoding, in which case we might
# be able to try another one on the *entire* set of subtitles elsewhere.
except ValueError as e:
if tolerant:
logger.warning(e)
continue
else:
raise
except StopIteration:
break
return subs_list
class GenericSubtitleParser(SubsMixin, TransformerMixin):
def __init__(self, fmt='srt', encoding='infer', caching=False, max_subtitle_seconds=None, start_seconds=0):
super(self.__class__, self).__init__()
self.sub_format = fmt
self.encoding = encoding
self.caching = caching
self.fit_fname = None
self.detected_encoding_ = None
self.sub_skippers = []
self.max_subtitle_seconds = max_subtitle_seconds
self.start_seconds = start_seconds
def fit(self, fname, *_):
if self.caching and self.fit_fname == fname:
return self
encodings_to_try = (self.encoding,)
with open_file(fname, 'rb') as f:
subs = f.read()
if self.encoding == 'infer':
encodings_to_try = (chardet.detect(subs)['encoding'],)
exc = None
for encoding in encodings_to_try:
try:
decoded_subs = subs.decode(encoding, errors='replace').strip()
if self.sub_format == 'srt':
parsed_subs = srt.parse(decoded_subs)
elif self.sub_format in ('ass', 'ssa'):
parsed_subs = pysubs2.SSAFile.from_string(decoded_subs)
else:
raise NotImplementedError('unsupported format: %s' % self.sub_format)
self.subs_ = GenericSubtitlesFile(
_preprocess_subs(parsed_subs,
max_subtitle_seconds=self.max_subtitle_seconds,
start_seconds=self.start_seconds),
sub_format=self.sub_format,
encoding=encoding
)
self.fit_fname = fname
self.detected_encoding_ = encoding
logger.info('detected encoding: %s' % self.detected_encoding_)
return self
except Exception as e:
exc = e
continue
raise exc
def transform(self, *_):
return self.subs_

View file

@ -0,0 +1,130 @@
# -*- coding: utf-8 -*-
from datetime import timedelta
import logging
import numbers
from .sklearn_shim import TransformerMixin
from .generic_subtitles import GenericSubtitle, GenericSubtitlesFile, SubsMixin
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SubtitleShifter(SubsMixin, TransformerMixin):
def __init__(self, td_seconds):
super(SubsMixin, self).__init__()
if not isinstance(td_seconds, timedelta):
self.td_seconds = timedelta(seconds=td_seconds)
else:
self.td_seconds = td_seconds
def fit(self, subs, *_):
self.subs_ = subs.offset(self.td_seconds)
return self
def transform(self, *_):
return self.subs_
class SubtitleScaler(SubsMixin, TransformerMixin):
def __init__(self, scale_factor):
assert isinstance(scale_factor, numbers.Number)
super(SubsMixin, self).__init__()
self.scale_factor = scale_factor
def fit(self, subs, *_):
scaled_subs = []
for sub in subs:
scaled_subs.append(
GenericSubtitle(
# py2 doesn't support direct multiplication of timedelta w/ float
timedelta(seconds=sub.start.total_seconds() * self.scale_factor),
timedelta(seconds=sub.end.total_seconds() * self.scale_factor),
sub.inner
)
)
self.subs_ = GenericSubtitlesFile(scaled_subs, sub_format=subs.sub_format, encoding=subs.encoding)
return self
def transform(self, *_):
return self.subs_
class SubtitleMerger(SubsMixin, TransformerMixin):
def __init__(self, reference_subs, first='reference'):
assert first in ('reference', 'output')
super(SubsMixin, self).__init__()
self.reference_subs = reference_subs
self.first = first
def fit(self, output_subs, *_):
def _merger_gen(a, b):
ita, itb = iter(a), iter(b)
cur_a = next(ita, None)
cur_b = next(itb, None)
while True:
if cur_a is None and cur_b is None:
return
elif cur_a is None:
while cur_b is not None:
yield cur_b
cur_b = next(itb, None)
return
elif cur_b is None:
while cur_a is not None:
yield cur_a
cur_a = next(ita, None)
return
# else: neither are None
if cur_a.start < cur_b.start:
swapped = False
else:
swapped = True
cur_a, cur_b = cur_b, cur_a
ita, itb = itb, ita
prev_a = cur_a
while prev_a is not None and cur_a.start < cur_b.start:
cur_a = next(ita, None)
if cur_a is None or cur_a.start < cur_b.start:
yield prev_a
prev_a = cur_a
if prev_a is None:
while cur_b is not None:
yield cur_b
cur_b = next(itb, None)
return
if cur_b.start - prev_a.start < cur_a.start - cur_b.start:
if swapped:
yield cur_b.merge_with(prev_a)
ita, itb = itb, ita
cur_a, cur_b = cur_b, cur_a
cur_a = next(ita, None)
else:
yield prev_a.merge_with(cur_b)
cur_b = next(itb, None)
else:
if swapped:
yield cur_b.merge_with(cur_a)
ita, itb = itb, ita
else:
yield cur_a.merge_with(cur_b)
cur_a = next(ita, None)
cur_b = next(itb, None)
merged_subs = []
if self.first == 'reference':
first, second = self.reference_subs, output_subs
else:
first, second = output_subs, self.reference_subs
for merged in _merger_gen(first, second):
merged_subs.append(merged)
self.subs_ = GenericSubtitlesFile(
merged_subs,
sub_format=output_subs.sub_format,
encoding=output_subs.encoding
)
return self
def transform(self, *_):
return self.subs_
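# Illustrative sketch (assumes a local 'example.srt'): parsing with the
# GenericSubtitleParser from .subtitle_parser, then shifting every cue forward
# by 1.5 seconds with SubtitleShifter.
#
#     from ffsubsync.subtitle_parser import GenericSubtitleParser
#     subs = GenericSubtitleParser(fmt='srt', encoding='infer').fit('example.srt').transform()
#     SubtitleShifter(1.5).fit_transform(subs).write_file('example.shifted.srt')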

24
libs/ffsubsync/version.py Normal file
View file

@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
__version__ = '0.4.3'
def make_version_tuple(vstr):
if vstr[0] == 'v':
vstr = vstr[1:]
return tuple(map(int, vstr.split('.')))
def update_available():
import requests
from requests.exceptions import Timeout
from .constants import API_RELEASE_URL
try:
resp = requests.get(API_RELEASE_URL, timeout=1)
latest_vstr = resp.json()['tag_name']
except Timeout:
return False
except KeyError:
return False
if not resp.ok:
return False
return make_version_tuple(__version__) < make_version_tuple(latest_vstr)
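# A small worked example (hypothetical version strings): tags are normalized by
# stripping a leading 'v' and compared as integer tuples, so 0.4.3 sorts before
# 0.4.10 even though a plain string comparison would get that wrong.
if __name__ == '__main__':
    assert make_version_tuple('v0.4.3') == (0, 4, 3)
    assert make_version_tuple('v0.4.3') < make_version_tuple('v0.4.10')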

93
libs/future/__init__.py Normal file
View file

@ -0,0 +1,93 @@
"""
future: Easy, safe support for Python 2/3 compatibility
=======================================================
``future`` is the missing compatibility layer between Python 2 and Python
3. It allows you to use a single, clean Python 3.x-compatible codebase to
support both Python 2 and Python 3 with minimal overhead.
It is designed to be used as follows::
from __future__ import (absolute_import, division,
print_function, unicode_literals)
from builtins import (
bytes, dict, int, list, object, range, str,
ascii, chr, hex, input, next, oct, open,
pow, round, super,
filter, map, zip)
followed by predominantly standard, idiomatic Python 3 code that then runs
similarly on Python 2.6/2.7 and Python 3.3+.
The imports have no effect on Python 3. On Python 2, they shadow the
corresponding builtins, which normally have different semantics on Python 3
versus 2, to provide their Python 3 semantics.
Standard library reorganization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``future`` supports the standard library reorganization (PEP 3108) through the
following Py3 interfaces:
>>> # Top-level packages with Py3 names provided on Py2:
>>> import html.parser
>>> import queue
>>> import tkinter.dialog
>>> import xmlrpc.client
>>> # etc.
>>> # Aliases provided for extensions to existing Py2 module names:
>>> from future.standard_library import install_aliases
>>> install_aliases()
>>> from collections import Counter, OrderedDict # backported to Py2.6
>>> from collections import UserDict, UserList, UserString
>>> import urllib.request
>>> from itertools import filterfalse, zip_longest
>>> from subprocess import getoutput, getstatusoutput
Automatic conversion
--------------------
An included script called `futurize
<http://python-future.org/automatic_conversion.html>`_ aids in converting
code (from either Python 2 or Python 3) to code compatible with both
platforms. It is similar to ``python-modernize`` but goes further in
providing Python 3 compatibility through the use of the backported types
and builtin functions in ``future``.
Documentation
-------------
See: http://python-future.org
Credits
-------
:Author: Ed Schofield
:Sponsor: Python Charmers Pty Ltd, Australia, and Python Charmers Pte
Ltd, Singapore. http://pythoncharmers.com
:Others: See docs/credits.rst or http://python-future.org/credits.html
Licensing
---------
Copyright 2013-2018 Python Charmers Pty Ltd, Australia.
The software is distributed under an MIT licence. See LICENSE.txt.
"""
__title__ = 'future'
__author__ = 'Ed Schofield'
__license__ = 'MIT'
__copyright__ = 'Copyright 2013-2018 Python Charmers Pty Ltd'
__ver_major__ = 0
__ver_minor__ = 17
__ver_patch__ = 0
__ver_sub__ = ''
__version__ = "%d.%d.%d%s" % (__ver_major__, __ver_minor__,
__ver_patch__, __ver_sub__)

View file

@ -0,0 +1,26 @@
"""
future.backports package
"""
from __future__ import absolute_import
import sys
__future_module__ = True
from future.standard_library import import_top_level_modules
if sys.version_info[0] == 3:
import_top_level_modules()
from .misc import (ceil,
OrderedDict,
Counter,
ChainMap,
check_output,
count,
recursive_repr,
_count_elements,
cmp_to_key
)

View file

@ -0,0 +1,422 @@
"""Shared support for scanning document type declarations in HTML and XHTML.
Backported for python-future from Python 3.3. Reason: ParserBase is an
old-style class in the Python 2.7 source of markupbase.py, which I suspect
might be the cause of sporadic unit-test failures on travis-ci.org with
test_htmlparser.py. The test failures look like this:
======================================================================
ERROR: test_attr_entity_replacement (future.tests.test_htmlparser.AttributesStrictTestCase)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/travis/build/edschofield/python-future/future/tests/test_htmlparser.py", line 661, in test_attr_entity_replacement
[("starttag", "a", [("b", "&><\"'")])])
File "/home/travis/build/edschofield/python-future/future/tests/test_htmlparser.py", line 93, in _run_check
collector = self.get_collector()
File "/home/travis/build/edschofield/python-future/future/tests/test_htmlparser.py", line 617, in get_collector
return EventCollector(strict=True)
File "/home/travis/build/edschofield/python-future/future/tests/test_htmlparser.py", line 27, in __init__
html.parser.HTMLParser.__init__(self, *args, **kw)
File "/home/travis/build/edschofield/python-future/future/backports/html/parser.py", line 135, in __init__
self.reset()
File "/home/travis/build/edschofield/python-future/future/backports/html/parser.py", line 143, in reset
_markupbase.ParserBase.reset(self)
TypeError: unbound method reset() must be called with ParserBase instance as first argument (got EventCollector instance instead)
This module is used as a foundation for the html.parser module. It has no
documented public API and should not be used directly.
"""
import re
_declname_match = re.compile(r'[a-zA-Z][-_.a-zA-Z0-9]*\s*').match
_declstringlit_match = re.compile(r'(\'[^\']*\'|"[^"]*")\s*').match
_commentclose = re.compile(r'--\s*>')
_markedsectionclose = re.compile(r']\s*]\s*>')
# An analysis of the MS-Word extensions is available at
# http://www.planetpublish.com/xmlarena/xap/Thursday/WordtoXML.pdf
_msmarkedsectionclose = re.compile(r']\s*>')
del re
class ParserBase(object):
"""Parser base class which provides some common support methods used
by the SGML/HTML and XHTML parsers."""
def __init__(self):
if self.__class__ is ParserBase:
raise RuntimeError(
"_markupbase.ParserBase must be subclassed")
def error(self, message):
raise NotImplementedError(
"subclasses of ParserBase must override error()")
def reset(self):
self.lineno = 1
self.offset = 0
def getpos(self):
"""Return current line number and offset."""
return self.lineno, self.offset
# Internal -- update line number and offset. This should be
# called for each piece of data exactly once, in order -- in other
# words the concatenation of all the input strings to this
# function should be exactly the entire input.
def updatepos(self, i, j):
if i >= j:
return j
rawdata = self.rawdata
nlines = rawdata.count("\n", i, j)
if nlines:
self.lineno = self.lineno + nlines
pos = rawdata.rindex("\n", i, j) # Should not fail
self.offset = j-(pos+1)
else:
self.offset = self.offset + j-i
return j
_decl_otherchars = ''
# Internal -- parse declaration (for use by subclasses).
def parse_declaration(self, i):
# This is some sort of declaration; in "HTML as
# deployed," this should only be the document type
# declaration ("<!DOCTYPE html...>").
# ISO 8879:1986, however, has more complex
# declaration syntax for elements in <!...>, including:
# --comment--
# [marked section]
# name in the following list: ENTITY, DOCTYPE, ELEMENT,
# ATTLIST, NOTATION, SHORTREF, USEMAP,
# LINKTYPE, LINK, IDLINK, USELINK, SYSTEM
rawdata = self.rawdata
j = i + 2
assert rawdata[i:j] == "<!", "unexpected call to parse_declaration"
if rawdata[j:j+1] == ">":
# the empty comment <!>
return j + 1
if rawdata[j:j+1] in ("-", ""):
# Start of comment followed by buffer boundary,
# or just a buffer boundary.
return -1
# A simple, practical version could look like: ((name|stringlit) S*) + '>'
n = len(rawdata)
if rawdata[j:j+2] == '--': #comment
# Locate --.*-- as the body of the comment
return self.parse_comment(i)
elif rawdata[j] == '[': #marked section
# Locate [statusWord [...arbitrary SGML...]] as the body of the marked section
# Where statusWord is one of TEMP, CDATA, IGNORE, INCLUDE, RCDATA
# Note that this is extended by Microsoft Office "Save as Web" function
# to include [if...] and [endif].
return self.parse_marked_section(i)
else: #all other declaration elements
decltype, j = self._scan_name(j, i)
if j < 0:
return j
if decltype == "doctype":
self._decl_otherchars = ''
while j < n:
c = rawdata[j]
if c == ">":
# end of declaration syntax
data = rawdata[i+2:j]
if decltype == "doctype":
self.handle_decl(data)
else:
# According to the HTML5 specs sections "8.2.4.44 Bogus
# comment state" and "8.2.4.45 Markup declaration open
# state", a comment token should be emitted.
# Calling unknown_decl provides more flexibility though.
self.unknown_decl(data)
return j + 1
if c in "\"'":
m = _declstringlit_match(rawdata, j)
if not m:
return -1 # incomplete
j = m.end()
elif c in "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ":
name, j = self._scan_name(j, i)
elif c in self._decl_otherchars:
j = j + 1
elif c == "[":
# this could be handled in a separate doctype parser
if decltype == "doctype":
j = self._parse_doctype_subset(j + 1, i)
elif decltype in set(["attlist", "linktype", "link", "element"]):
# must tolerate []'d groups in a content model in an element declaration
# also in data attribute specifications of attlist declaration
# also link type declaration subsets in linktype declarations
# also link attribute specification lists in link declarations
self.error("unsupported '[' char in %s declaration" % decltype)
else:
self.error("unexpected '[' char in declaration")
else:
self.error(
"unexpected %r char in declaration" % rawdata[j])
if j < 0:
return j
return -1 # incomplete
# Internal -- parse a marked section
# Override this to handle MS-word extension syntax <![if word]>content<![endif]>
def parse_marked_section(self, i, report=1):
rawdata= self.rawdata
assert rawdata[i:i+3] == '<![', "unexpected call to parse_marked_section()"
sectName, j = self._scan_name( i+3, i )
if j < 0:
return j
if sectName in set(["temp", "cdata", "ignore", "include", "rcdata"]):
# look for standard ]]> ending
match= _markedsectionclose.search(rawdata, i+3)
elif sectName in set(["if", "else", "endif"]):
# look for MS Office ]> ending
match= _msmarkedsectionclose.search(rawdata, i+3)
else:
self.error('unknown status keyword %r in marked section' % rawdata[i+3:j])
if not match:
return -1
if report:
j = match.start(0)
self.unknown_decl(rawdata[i+3: j])
return match.end(0)
# Internal -- parse comment, return length or -1 if not terminated
def parse_comment(self, i, report=1):
rawdata = self.rawdata
if rawdata[i:i+4] != '<!--':
self.error('unexpected call to parse_comment()')
match = _commentclose.search(rawdata, i+4)
if not match:
return -1
if report:
j = match.start(0)
self.handle_comment(rawdata[i+4: j])
return match.end(0)
# Internal -- scan past the internal subset in a <!DOCTYPE declaration,
# returning the index just past any whitespace following the trailing ']'.
def _parse_doctype_subset(self, i, declstartpos):
rawdata = self.rawdata
n = len(rawdata)
j = i
while j < n:
c = rawdata[j]
if c == "<":
s = rawdata[j:j+2]
if s == "<":
# end of buffer; incomplete
return -1
if s != "<!":
self.updatepos(declstartpos, j + 1)
self.error("unexpected char in internal subset (in %r)" % s)
if (j + 2) == n:
# end of buffer; incomplete
return -1
if (j + 4) > n:
# end of buffer; incomplete
return -1
if rawdata[j:j+4] == "<!--":
j = self.parse_comment(j, report=0)
if j < 0:
return j
continue
name, j = self._scan_name(j + 2, declstartpos)
if j == -1:
return -1
if name not in set(["attlist", "element", "entity", "notation"]):
self.updatepos(declstartpos, j + 2)
self.error(
"unknown declaration %r in internal subset" % name)
# handle the individual names
meth = getattr(self, "_parse_doctype_" + name)
j = meth(j, declstartpos)
if j < 0:
return j
elif c == "%":
# parameter entity reference
if (j + 1) == n:
# end of buffer; incomplete
return -1
s, j = self._scan_name(j + 1, declstartpos)
if j < 0:
return j
if rawdata[j] == ";":
j = j + 1
elif c == "]":
j = j + 1
while j < n and rawdata[j].isspace():
j = j + 1
if j < n:
if rawdata[j] == ">":
return j
self.updatepos(declstartpos, j)
self.error("unexpected char after internal subset")
else:
return -1
elif c.isspace():
j = j + 1
else:
self.updatepos(declstartpos, j)
self.error("unexpected char %r in internal subset" % c)
# end of buffer reached
return -1
# Internal -- scan past <!ELEMENT declarations
def _parse_doctype_element(self, i, declstartpos):
name, j = self._scan_name(i, declstartpos)
if j == -1:
return -1
# style content model; just skip until '>'
rawdata = self.rawdata
if '>' in rawdata[j:]:
return rawdata.find(">", j) + 1
return -1
# Internal -- scan past <!ATTLIST declarations
def _parse_doctype_attlist(self, i, declstartpos):
rawdata = self.rawdata
name, j = self._scan_name(i, declstartpos)
c = rawdata[j:j+1]
if c == "":
return -1
if c == ">":
return j + 1
while 1:
# scan a series of attribute descriptions; simplified:
# name type [value] [#constraint]
name, j = self._scan_name(j, declstartpos)
if j < 0:
return j
c = rawdata[j:j+1]
if c == "":
return -1
if c == "(":
# an enumerated type; look for ')'
if ")" in rawdata[j:]:
j = rawdata.find(")", j) + 1
else:
return -1
while rawdata[j:j+1].isspace():
j = j + 1
if not rawdata[j:]:
# end of buffer, incomplete
return -1
else:
name, j = self._scan_name(j, declstartpos)
c = rawdata[j:j+1]
if not c:
return -1
if c in "'\"":
m = _declstringlit_match(rawdata, j)
if m:
j = m.end()
else:
return -1
c = rawdata[j:j+1]
if not c:
return -1
if c == "#":
if rawdata[j:] == "#":
# end of buffer
return -1
name, j = self._scan_name(j + 1, declstartpos)
if j < 0:
return j
c = rawdata[j:j+1]
if not c:
return -1
if c == '>':
# all done
return j + 1
# Internal -- scan past <!NOTATION declarations
def _parse_doctype_notation(self, i, declstartpos):
name, j = self._scan_name(i, declstartpos)
if j < 0:
return j
rawdata = self.rawdata
while 1:
c = rawdata[j:j+1]
if not c:
# end of buffer; incomplete
return -1
if c == '>':
return j + 1
if c in "'\"":
m = _declstringlit_match(rawdata, j)
if not m:
return -1
j = m.end()
else:
name, j = self._scan_name(j, declstartpos)
if j < 0:
return j
# Internal -- scan past <!ENTITY declarations
def _parse_doctype_entity(self, i, declstartpos):
rawdata = self.rawdata
if rawdata[i:i+1] == "%":
j = i + 1
while 1:
c = rawdata[j:j+1]
if not c:
return -1
if c.isspace():
j = j + 1
else:
break
else:
j = i
name, j = self._scan_name(j, declstartpos)
if j < 0:
return j
while 1:
c = self.rawdata[j:j+1]
if not c:
return -1
if c in "'\"":
m = _declstringlit_match(rawdata, j)
if m:
j = m.end()
else:
return -1 # incomplete
elif c == ">":
return j + 1
else:
name, j = self._scan_name(j, declstartpos)
if j < 0:
return j
# Internal -- scan a name token and the new position and the token, or
# return -1 if we've reached the end of the buffer.
def _scan_name(self, i, declstartpos):
rawdata = self.rawdata
n = len(rawdata)
if i == n:
return None, -1
m = _declname_match(rawdata, i)
if m:
s = m.group()
name = s.strip()
if (i + len(s)) == n:
return None, -1 # end of buffer
return name.lower(), m.end()
else:
self.updatepos(declstartpos, i)
self.error("expected name token at %r"
% rawdata[declstartpos:declstartpos+20])
# To be overridden -- handlers for unknown objects
def unknown_decl(self, data):
pass

File diff suppressed because it is too large

View file

@ -0,0 +1,78 @@
# Copyright (C) 2001-2007 Python Software Foundation
# Author: Barry Warsaw
# Contact: email-sig@python.org
"""
Backport of the Python 3.3 email package for Python-Future.
A package for parsing, handling, and generating email messages.
"""
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
# Install the surrogate escape handler here because this is used by many
# modules in the email package.
from future.utils import surrogateescape
surrogateescape.register_surrogateescape()
# (Should this be done globally by ``future``?)
__version__ = '5.1.0'
__all__ = [
'base64mime',
'charset',
'encoders',
'errors',
'feedparser',
'generator',
'header',
'iterators',
'message',
'message_from_file',
'message_from_binary_file',
'message_from_string',
'message_from_bytes',
'mime',
'parser',
'quoprimime',
'utils',
]
# Some convenience routines. Don't import Parser and Message as side-effects
# of importing email since those cascadingly import most of the rest of the
# email package.
def message_from_string(s, *args, **kws):
"""Parse a string into a Message object model.
Optional _class and strict are passed to the Parser constructor.
"""
from future.backports.email.parser import Parser
return Parser(*args, **kws).parsestr(s)
def message_from_bytes(s, *args, **kws):
"""Parse a bytes string into a Message object model.
Optional _class and strict are passed to the Parser constructor.
"""
from future.backports.email.parser import BytesParser
return BytesParser(*args, **kws).parsebytes(s)
def message_from_file(fp, *args, **kws):
"""Read a file and parse its contents into a Message object model.
Optional _class and strict are passed to the Parser constructor.
"""
from future.backports.email.parser import Parser
return Parser(*args, **kws).parse(fp)
def message_from_binary_file(fp, *args, **kws):
"""Read a binary file and parse its contents into a Message object model.
Optional _class and strict are passed to the Parser constructor.
"""
from future.backports.email.parser import BytesParser
return BytesParser(*args, **kws).parse(fp)
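# Illustrative usage sketch (the message text is an assumption): the convenience
# parsers return Message objects whose headers are addressable by name.
#
#     msg = message_from_string('Subject: hello\r\n\r\nbody text\r\n')
#     msg['Subject']   # -> 'hello'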

View file

@ -0,0 +1,232 @@
""" Routines for manipulating RFC2047 encoded words.
This is currently a package-private API, but will be considered for promotion
to a public API if there is demand.
"""
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from future.builtins import bytes
from future.builtins import chr
from future.builtins import int
from future.builtins import str
# An encoded word looks like this:
#
# =?charset[*lang]?cte?encoded_string?=
#
# for more information about charset see the charset module. Here it is one
# of the preferred MIME charset names (hopefully; you never know when parsing).
# cte (Content Transfer Encoding) is either 'q' or 'b' (ignoring case). In
# theory other letters could be used for other encodings, but in practice this
# (almost?) never happens. There could be a public API for adding entries
# to the CTE tables, but YAGNI for now. 'q' is Quoted Printable, 'b' is
# Base64. The meaning of encoded_string should be obvious. 'lang' is optional
# as indicated by the brackets (they are not part of the syntax) but is almost
# never encountered in practice.
#
# The general interface for a CTE decoder is that it takes the encoded_string
# as its argument, and returns a tuple (cte_decoded_string, defects). The
# cte_decoded_string is the original binary that was encoded using the
# specified cte. 'defects' is a list of MessageDefect instances indicating any
# problems encountered during conversion. 'charset' and 'lang' are the
# corresponding strings extracted from the EW, case preserved.
#
# The general interface for a CTE encoder is that it takes a binary sequence
# as input and returns the cte_encoded_string, which is an ascii-only string.
#
# Each decoder must also supply a length function that takes the binary
# sequence as its argument and returns the length of the resulting encoded
# string.
#
# The main API functions for the module are decode, which calls the decoder
# referenced by the cte specifier, and encode, which adds the appropriate
# RFC 2047 "chrome" to the encoded string, and can optionally automatically
# select the shortest possible encoding. See their docstrings below for
# details.
import re
import base64
import binascii
import functools
from string import ascii_letters, digits
from future.backports.email import errors
__all__ = ['decode_q',
'encode_q',
'decode_b',
'encode_b',
'len_q',
'len_b',
'decode',
'encode',
]
#
# Quoted Printable
#
# regex based decoder.
_q_byte_subber = functools.partial(re.compile(br'=([a-fA-F0-9]{2})').sub,
lambda m: bytes([int(m.group(1), 16)]))
def decode_q(encoded):
encoded = bytes(encoded.replace(b'_', b' '))
return _q_byte_subber(encoded), []
# dict mapping bytes to their encoded form
class _QByteMap(dict):
safe = bytes(b'-!*+/' + ascii_letters.encode('ascii') + digits.encode('ascii'))
def __missing__(self, key):
if key in self.safe:
self[key] = chr(key)
else:
self[key] = "={:02X}".format(key)
return self[key]
_q_byte_map = _QByteMap()
# In headers spaces are mapped to '_'.
_q_byte_map[ord(' ')] = '_'
def encode_q(bstring):
return str(''.join(_q_byte_map[x] for x in bytes(bstring)))
def len_q(bstring):
return sum(len(_q_byte_map[x]) for x in bytes(bstring))
#
# Base64
#
def decode_b(encoded):
defects = []
pad_err = len(encoded) % 4
if pad_err:
defects.append(errors.InvalidBase64PaddingDefect())
padded_encoded = encoded + b'==='[:4-pad_err]
else:
padded_encoded = encoded
try:
# The validate kwarg to b64decode is not supported in Py2.x
if not re.match(b'^[A-Za-z0-9+/]*={0,2}$', padded_encoded):
raise binascii.Error('Non-base64 digit found')
return base64.b64decode(padded_encoded), defects
except binascii.Error:
# Since we had correct padding, this must be an invalid char error.
defects = [errors.InvalidBase64CharactersDefect()]
# The non-alphabet characters are ignored as far as padding
# goes, but we don't know how many there are. So we'll just
# try various padding lengths until something works.
for i in 0, 1, 2, 3:
try:
return base64.b64decode(encoded+b'='*i), defects
except (binascii.Error, TypeError): # Py2 raises a TypeError
if i==0:
defects.append(errors.InvalidBase64PaddingDefect())
else:
# This should never happen.
raise AssertionError("unexpected binascii.Error")
def encode_b(bstring):
return base64.b64encode(bstring).decode('ascii')
def len_b(bstring):
groups_of_3, leftover = divmod(len(bstring), 3)
# 4 bytes out for each 3 bytes (or nonzero fraction thereof) in.
return groups_of_3 * 4 + (4 if leftover else 0)
_cte_decoders = {
'q': decode_q,
'b': decode_b,
}
def decode(ew):
"""Decode encoded word and return (string, charset, lang, defects) tuple.
An RFC 2047/2243 encoded word has the form:
=?charset*lang?cte?encoded_string?=
where '*lang' may be omitted but the other parts may not be.
This function expects exactly such a string (that is, it does not check the
syntax and may raise errors if the string is not well formed), and returns
the encoded_string decoded first from its Content Transfer Encoding and
then from the resulting bytes into unicode using the specified charset. If
the cte-decoded string does not successfully decode using the specified
character set, a defect is added to the defects list and the unknown octets
are replaced by the unicode 'unknown' character \uFFFD.
The specified charset and language are returned. The default for language,
which is rarely if ever encountered, is the empty string.
"""
_, charset, cte, cte_string, _ = str(ew).split('?')
charset, _, lang = charset.partition('*')
cte = cte.lower()
# Recover the original bytes and do CTE decoding.
bstring = cte_string.encode('ascii', 'surrogateescape')
bstring, defects = _cte_decoders[cte](bstring)
# Turn the CTE decoded bytes into unicode.
try:
string = bstring.decode(charset)
except UnicodeError:
defects.append(errors.UndecodableBytesDefect("Encoded word "
"contains bytes not decodable using {} charset".format(charset)))
string = bstring.decode(charset, 'surrogateescape')
except LookupError:
string = bstring.decode('ascii', 'surrogateescape')
if charset.lower() != 'unknown-8bit':
defects.append(errors.CharsetError("Unknown charset {} "
"in encoded word; decoded as unknown bytes".format(charset)))
return string, charset, lang, defects
_cte_encoders = {
'q': encode_q,
'b': encode_b,
}
_cte_encode_length = {
'q': len_q,
'b': len_b,
}
def encode(string, charset='utf-8', encoding=None, lang=''):
"""Encode string using the CTE encoding that produces the shorter result.
Produces an RFC 2047/2243 encoded word of the form:
=?charset*lang?cte?encoded_string?=
where '*lang' is omitted unless the 'lang' parameter is given a value.
Optional argument charset (defaults to utf-8) specifies the charset to use
to encode the string to binary before CTE encoding it. Optional argument
'encoding' is the cte specifier for the encoding that should be used ('q'
or 'b'); if it is None (the default) the encoding which produces the
shortest encoded sequence is used, except that 'q' is preferred if it is up
to five characters longer. Optional argument 'lang' (default '') gives the
RFC 2243 language string to specify in the encoded word.
"""
string = str(string)
if charset == 'unknown-8bit':
bstring = string.encode('ascii', 'surrogateescape')
else:
bstring = string.encode(charset)
if encoding is None:
qlen = _cte_encode_length['q'](bstring)
blen = _cte_encode_length['b'](bstring)
# Bias toward q. 5 is arbitrary.
encoding = 'q' if qlen - blen < 5 else 'b'
encoded = _cte_encoders[encoding](bstring)
if lang:
lang = '*' + lang
return "=?{0}{1}?{2}?{3}?=".format(charset, lang, encoding, encoded)

File diff suppressed because it is too large

View file

@@ -0,0 +1,546 @@
# Copyright (C) 2002-2007 Python Software Foundation
# Contact: email-sig@python.org
"""Email address parsing code.
Lifted directly from rfc822.py. This should eventually be rewritten.
"""
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from future.builtins import int
__all__ = [
'mktime_tz',
'parsedate',
'parsedate_tz',
'quote',
]
import time, calendar
SPACE = ' '
EMPTYSTRING = ''
COMMASPACE = ', '
# Parse a date field
_monthnames = ['jan', 'feb', 'mar', 'apr', 'may', 'jun', 'jul',
'aug', 'sep', 'oct', 'nov', 'dec',
'january', 'february', 'march', 'april', 'may', 'june', 'july',
'august', 'september', 'october', 'november', 'december']
_daynames = ['mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun']
# The timezone table does not include the military time zones defined
# in RFC822, other than Z. According to RFC1123, the description in
# RFC822 gets the signs wrong, so we can't rely on any such time
# zones. RFC1123 recommends that numeric timezone indicators be used
# instead of timezone names.
_timezones = {'UT':0, 'UTC':0, 'GMT':0, 'Z':0,
'AST': -400, 'ADT': -300, # Atlantic (used in Canada)
'EST': -500, 'EDT': -400, # Eastern
'CST': -600, 'CDT': -500, # Central
'MST': -700, 'MDT': -600, # Mountain
'PST': -800, 'PDT': -700 # Pacific
}
def parsedate_tz(data):
"""Convert a date string to a time tuple.
Accounts for military timezones.
"""
res = _parsedate_tz(data)
if not res:
return
if res[9] is None:
res[9] = 0
return tuple(res)
def _parsedate_tz(data):
"""Convert date to extended time tuple.
The last (additional) element is the time zone offset in seconds, except if
the timezone was specified as -0000. In that case the last element is
None. This indicates a UTC timestamp that explicitly declaims knowledge of
the source timezone, as opposed to a +0000 timestamp that indicates the
source timezone really was UTC.
"""
if not data:
return
data = data.split()
# The FWS after the comma after the day-of-week is optional, so search and
# adjust for this.
if data[0].endswith(',') or data[0].lower() in _daynames:
# There's a dayname here. Skip it
del data[0]
else:
i = data[0].rfind(',')
if i >= 0:
data[0] = data[0][i+1:]
if len(data) == 3: # RFC 850 date, deprecated
stuff = data[0].split('-')
if len(stuff) == 3:
data = stuff + data[1:]
if len(data) == 4:
s = data[3]
i = s.find('+')
if i == -1:
i = s.find('-')
if i > 0:
data[3:] = [s[:i], s[i:]]
else:
data.append('') # Dummy tz
if len(data) < 5:
return None
data = data[:5]
[dd, mm, yy, tm, tz] = data
mm = mm.lower()
if mm not in _monthnames:
dd, mm = mm, dd.lower()
if mm not in _monthnames:
return None
mm = _monthnames.index(mm) + 1
if mm > 12:
mm -= 12
if dd[-1] == ',':
dd = dd[:-1]
i = yy.find(':')
if i > 0:
yy, tm = tm, yy
if yy[-1] == ',':
yy = yy[:-1]
if not yy[0].isdigit():
yy, tz = tz, yy
if tm[-1] == ',':
tm = tm[:-1]
tm = tm.split(':')
if len(tm) == 2:
[thh, tmm] = tm
tss = '0'
elif len(tm) == 3:
[thh, tmm, tss] = tm
elif len(tm) == 1 and '.' in tm[0]:
# Some non-compliant MUAs use '.' to separate time elements.
tm = tm[0].split('.')
if len(tm) == 2:
[thh, tmm] = tm
tss = 0
elif len(tm) == 3:
[thh, tmm, tss] = tm
else:
return None
try:
yy = int(yy)
dd = int(dd)
thh = int(thh)
tmm = int(tmm)
tss = int(tss)
except ValueError:
return None
# Check for a yy specified in two-digit format, then convert it to the
# appropriate four-digit format, according to the POSIX standard. RFC 822
# calls for a two-digit yy, but RFC 2822 (which obsoletes RFC 822)
# mandates a 4-digit yy. For more information, see the documentation for
# the time module.
if yy < 100:
# The year is between 1969 and 1999 (inclusive).
if yy > 68:
yy += 1900
# The year is between 2000 and 2068 (inclusive).
else:
yy += 2000
tzoffset = None
tz = tz.upper()
if tz in _timezones:
tzoffset = _timezones[tz]
else:
try:
tzoffset = int(tz)
except ValueError:
pass
if tzoffset==0 and tz.startswith('-'):
tzoffset = None
# Convert a timezone offset into seconds ; -0500 -> -18000
if tzoffset:
if tzoffset < 0:
tzsign = -1
tzoffset = -tzoffset
else:
tzsign = 1
tzoffset = tzsign * ( (tzoffset//100)*3600 + (tzoffset % 100)*60)
# Daylight Saving Time flag is set to -1, since DST is unknown.
return [yy, mm, dd, thh, tmm, tss, 0, 1, -1, tzoffset]
def parsedate(data):
"""Convert a time string to a time tuple."""
t = parsedate_tz(data)
if isinstance(t, tuple):
return t[:9]
else:
return t
def mktime_tz(data):
"""Turn a 10-tuple as returned by parsedate_tz() into a POSIX timestamp."""
if data[9] is None:
# No zone info, so localtime is better assumption than GMT
return time.mktime(data[:8] + (-1,))
else:
t = calendar.timegm(data)
return t - data[9]
def quote(str):
"""Prepare string to be used in a quoted string.
Turns backslash and double quote characters into quoted pairs. These
are the only characters that need to be quoted inside a quoted string.
Does not add the surrounding double quotes.
"""
return str.replace('\\', '\\\\').replace('"', '\\"')
class AddrlistClass(object):
"""Address parser class by Ben Escoto.
To understand what this class does, it helps to have a copy of RFC 2822 in
front of you.
Note: this class interface is deprecated and may be removed in the future.
Use email.utils.AddressList instead.
"""
def __init__(self, field):
"""Initialize a new instance.
`field' is an unparsed address header field, containing
one or more addresses.
"""
self.specials = '()<>@,:;.\"[]'
self.pos = 0
self.LWS = ' \t'
self.CR = '\r\n'
self.FWS = self.LWS + self.CR
self.atomends = self.specials + self.LWS + self.CR
# Note that RFC 2822 now specifies `.' as obs-phrase, meaning that it
# is obsolete syntax. RFC 2822 requires that we recognize obsolete
# syntax, so allow dots in phrases.
self.phraseends = self.atomends.replace('.', '')
self.field = field
self.commentlist = []
def gotonext(self):
"""Skip white space and extract comments."""
wslist = []
while self.pos < len(self.field):
if self.field[self.pos] in self.LWS + '\n\r':
if self.field[self.pos] not in '\n\r':
wslist.append(self.field[self.pos])
self.pos += 1
elif self.field[self.pos] == '(':
self.commentlist.append(self.getcomment())
else:
break
return EMPTYSTRING.join(wslist)
def getaddrlist(self):
"""Parse all addresses.
Returns a list containing all of the addresses.
"""
result = []
while self.pos < len(self.field):
ad = self.getaddress()
if ad:
result += ad
else:
result.append(('', ''))
return result
def getaddress(self):
"""Parse the next address."""
self.commentlist = []
self.gotonext()
oldpos = self.pos
oldcl = self.commentlist
plist = self.getphraselist()
self.gotonext()
returnlist = []
if self.pos >= len(self.field):
# Bad email address technically, no domain.
if plist:
returnlist = [(SPACE.join(self.commentlist), plist[0])]
elif self.field[self.pos] in '.@':
# email address is just an addrspec
# this isn't very efficient since we start over
self.pos = oldpos
self.commentlist = oldcl
addrspec = self.getaddrspec()
returnlist = [(SPACE.join(self.commentlist), addrspec)]
elif self.field[self.pos] == ':':
# address is a group
returnlist = []
fieldlen = len(self.field)
self.pos += 1
while self.pos < len(self.field):
self.gotonext()
if self.pos < fieldlen and self.field[self.pos] == ';':
self.pos += 1
break
returnlist = returnlist + self.getaddress()
elif self.field[self.pos] == '<':
# Address is a phrase then a route addr
routeaddr = self.getrouteaddr()
if self.commentlist:
returnlist = [(SPACE.join(plist) + ' (' +
' '.join(self.commentlist) + ')', routeaddr)]
else:
returnlist = [(SPACE.join(plist), routeaddr)]
else:
if plist:
returnlist = [(SPACE.join(self.commentlist), plist[0])]
elif self.field[self.pos] in self.specials:
self.pos += 1
self.gotonext()
if self.pos < len(self.field) and self.field[self.pos] == ',':
self.pos += 1
return returnlist
def getrouteaddr(self):
"""Parse a route address (Return-path value).
This method just skips all the route stuff and returns the addrspec.
"""
if self.field[self.pos] != '<':
return
expectroute = False
self.pos += 1
self.gotonext()
adlist = ''
while self.pos < len(self.field):
if expectroute:
self.getdomain()
expectroute = False
elif self.field[self.pos] == '>':
self.pos += 1
break
elif self.field[self.pos] == '@':
self.pos += 1
expectroute = True
elif self.field[self.pos] == ':':
self.pos += 1
else:
adlist = self.getaddrspec()
self.pos += 1
break
self.gotonext()
return adlist
def getaddrspec(self):
"""Parse an RFC 2822 addr-spec."""
aslist = []
self.gotonext()
while self.pos < len(self.field):
preserve_ws = True
if self.field[self.pos] == '.':
if aslist and not aslist[-1].strip():
aslist.pop()
aslist.append('.')
self.pos += 1
preserve_ws = False
elif self.field[self.pos] == '"':
aslist.append('"%s"' % quote(self.getquote()))
elif self.field[self.pos] in self.atomends:
if aslist and not aslist[-1].strip():
aslist.pop()
break
else:
aslist.append(self.getatom())
ws = self.gotonext()
if preserve_ws and ws:
aslist.append(ws)
if self.pos >= len(self.field) or self.field[self.pos] != '@':
return EMPTYSTRING.join(aslist)
aslist.append('@')
self.pos += 1
self.gotonext()
return EMPTYSTRING.join(aslist) + self.getdomain()
def getdomain(self):
"""Get the complete domain name from an address."""
sdlist = []
while self.pos < len(self.field):
if self.field[self.pos] in self.LWS:
self.pos += 1
elif self.field[self.pos] == '(':
self.commentlist.append(self.getcomment())
elif self.field[self.pos] == '[':
sdlist.append(self.getdomainliteral())
elif self.field[self.pos] == '.':
self.pos += 1
sdlist.append('.')
elif self.field[self.pos] in self.atomends:
break
else:
sdlist.append(self.getatom())
return EMPTYSTRING.join(sdlist)
def getdelimited(self, beginchar, endchars, allowcomments=True):
"""Parse a header fragment delimited by special characters.
`beginchar' is the start character for the fragment.
If self is not looking at an instance of `beginchar' then
getdelimited returns the empty string.
`endchars' is a sequence of allowable end-delimiting characters.
Parsing stops when one of these is encountered.
If `allowcomments' is non-zero, embedded RFC 2822 comments are allowed
within the parsed fragment.
"""
if self.field[self.pos] != beginchar:
return ''
slist = ['']
quote = False
self.pos += 1
while self.pos < len(self.field):
if quote:
slist.append(self.field[self.pos])
quote = False
elif self.field[self.pos] in endchars:
self.pos += 1
break
elif allowcomments and self.field[self.pos] == '(':
slist.append(self.getcomment())
continue # have already advanced pos from getcomment
elif self.field[self.pos] == '\\':
quote = True
else:
slist.append(self.field[self.pos])
self.pos += 1
return EMPTYSTRING.join(slist)
def getquote(self):
"""Get a quote-delimited fragment from self's field."""
return self.getdelimited('"', '"\r', False)
def getcomment(self):
"""Get a parenthesis-delimited fragment from self's field."""
return self.getdelimited('(', ')\r', True)
def getdomainliteral(self):
"""Parse an RFC 2822 domain-literal."""
return '[%s]' % self.getdelimited('[', ']\r', False)
def getatom(self, atomends=None):
"""Parse an RFC 2822 atom.
Optional atomends specifies a different set of end token delimiters
(the default is to use self.atomends). This is used e.g. in
getphraselist() since phrase endings must not include the `.' (which
is legal in phrases)."""
atomlist = ['']
if atomends is None:
atomends = self.atomends
while self.pos < len(self.field):
if self.field[self.pos] in atomends:
break
else:
atomlist.append(self.field[self.pos])
self.pos += 1
return EMPTYSTRING.join(atomlist)
def getphraselist(self):
"""Parse a sequence of RFC 2822 phrases.
A phrase is a sequence of words, which are in turn either RFC 2822
atoms or quoted-strings. Phrases are canonicalized by squeezing all
runs of continuous whitespace into one space.
"""
plist = []
while self.pos < len(self.field):
if self.field[self.pos] in self.FWS:
self.pos += 1
elif self.field[self.pos] == '"':
plist.append(self.getquote())
elif self.field[self.pos] == '(':
self.commentlist.append(self.getcomment())
elif self.field[self.pos] in self.phraseends:
break
else:
plist.append(self.getatom(self.phraseends))
return plist
class AddressList(AddrlistClass):
"""An AddressList encapsulates a list of parsed RFC 2822 addresses."""
def __init__(self, field):
AddrlistClass.__init__(self, field)
if field:
self.addresslist = self.getaddrlist()
else:
self.addresslist = []
def __len__(self):
return len(self.addresslist)
def __add__(self, other):
# Set union
newaddr = AddressList(None)
newaddr.addresslist = self.addresslist[:]
for x in other.addresslist:
if not x in self.addresslist:
newaddr.addresslist.append(x)
return newaddr
def __iadd__(self, other):
# Set union, in-place
for x in other.addresslist:
if not x in self.addresslist:
self.addresslist.append(x)
return self
def __sub__(self, other):
# Set difference
newaddr = AddressList(None)
for x in self.addresslist:
if not x in other.addresslist:
newaddr.addresslist.append(x)
return newaddr
def __isub__(self, other):
# Set difference, in-place
for x in other.addresslist:
if x in self.addresslist:
self.addresslist.remove(x)
return self
def __getitem__(self, index):
# Make indexing, slices, and 'in' work
return self.addresslist[index]
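A short sketch of the date and address helpers above (illustrative only; the date string and addresses are made up):

from future.backports.email._parseaddr import AddressList, mktime_tz, parsedate_tz

tt = parsedate_tz('Wed, 10 Jun 2020 12:04:54 -0400')
print(mktime_tz(tt))        # POSIX timestamp, adjusted for the -0400 offset

al = AddressList('Louis <louis@example.org>, noreply@example.org')
print(al.addresslist)       # [('Louis', 'louis@example.org'), ('', 'noreply@example.org')]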

View file

@@ -0,0 +1,365 @@
"""Policy framework for the email package.
Allows fine grained feature control of how the package parses and emits data.
"""
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from future.builtins import super
from future.builtins import str
from future.utils import with_metaclass
import abc
from future.backports.email import header
from future.backports.email import charset as _charset
from future.backports.email.utils import _has_surrogates
__all__ = [
'Policy',
'Compat32',
'compat32',
]
class _PolicyBase(object):
"""Policy Object basic framework.
This class is useless unless subclassed. A subclass should define
class attributes with defaults for any values that are to be
managed by the Policy object. The constructor will then allow
non-default values to be set for these attributes at instance
creation time. The instance will be callable, taking these same
attributes keyword arguments, and returning a new instance
identical to the called instance except for those values changed
by the keyword arguments. Instances may be added, yielding new
instances with any non-default values from the right hand
operand overriding those in the left hand operand. That is,
A + B == A(<non-default values of B>)
The repr of an instance can be used to reconstruct the object
if and only if the repr of the values can be used to reconstruct
those values.
"""
def __init__(self, **kw):
"""Create new Policy, possibly overriding some defaults.
See class docstring for a list of overridable attributes.
"""
for name, value in kw.items():
if hasattr(self, name):
super(_PolicyBase,self).__setattr__(name, value)
else:
raise TypeError(
"{!r} is an invalid keyword argument for {}".format(
name, self.__class__.__name__))
def __repr__(self):
args = [ "{}={!r}".format(name, value)
for name, value in self.__dict__.items() ]
return "{}({})".format(self.__class__.__name__, ', '.join(args))
def clone(self, **kw):
"""Return a new instance with specified attributes changed.
The new instance has the same attribute values as the current object,
except for the changes passed in as keyword arguments.
"""
newpolicy = self.__class__.__new__(self.__class__)
for attr, value in self.__dict__.items():
object.__setattr__(newpolicy, attr, value)
for attr, value in kw.items():
if not hasattr(self, attr):
raise TypeError(
"{!r} is an invalid keyword argument for {}".format(
attr, self.__class__.__name__))
object.__setattr__(newpolicy, attr, value)
return newpolicy
def __setattr__(self, name, value):
if hasattr(self, name):
msg = "{!r} object attribute {!r} is read-only"
else:
msg = "{!r} object has no attribute {!r}"
raise AttributeError(msg.format(self.__class__.__name__, name))
def __add__(self, other):
"""Non-default values from right operand override those from left.
The object returned is a new instance of the subclass.
"""
return self.clone(**other.__dict__)
def _append_doc(doc, added_doc):
doc = doc.rsplit('\n', 1)[0]
added_doc = added_doc.split('\n', 1)[1]
return doc + '\n' + added_doc
def _extend_docstrings(cls):
if cls.__doc__ and cls.__doc__.startswith('+'):
cls.__doc__ = _append_doc(cls.__bases__[0].__doc__, cls.__doc__)
for name, attr in cls.__dict__.items():
if attr.__doc__ and attr.__doc__.startswith('+'):
for c in (c for base in cls.__bases__ for c in base.mro()):
doc = getattr(getattr(c, name), '__doc__')
if doc:
attr.__doc__ = _append_doc(doc, attr.__doc__)
break
return cls
class Policy(with_metaclass(abc.ABCMeta, _PolicyBase)):
r"""Controls for how messages are interpreted and formatted.
Most of the classes and many of the methods in the email package accept
Policy objects as parameters. A Policy object contains a set of values and
functions that control how input is interpreted and how output is rendered.
For example, the parameter 'raise_on_defect' controls whether or not an RFC
violation results in an error being raised or not, while 'max_line_length'
controls the maximum length of output lines when a Message is serialized.
Any valid attribute may be overridden when a Policy is created by passing
it as a keyword argument to the constructor. Policy objects are immutable,
but a new Policy object can be created with only certain values changed by
calling the Policy instance with keyword arguments. Policy objects can
also be added, producing a new Policy object in which the non-default
attributes set in the right hand operand overwrite those specified in the
left operand.
Settable attributes:
raise_on_defect -- If true, then defects should be raised as errors.
Default: False.
linesep -- string containing the value to use as separation
between output lines. Default '\n'.
cte_type -- Type of allowed content transfer encodings
7bit -- ASCII only
8bit -- Content-Transfer-Encoding: 8bit is allowed
Default: 8bit. Also controls the disposition of
(RFC invalid) binary data in headers; see the
documentation of the binary_fold method.
max_line_length -- maximum length of lines, excluding 'linesep',
during serialization. None or 0 means no line
wrapping is done. Default is 78.
"""
raise_on_defect = False
linesep = '\n'
cte_type = '8bit'
max_line_length = 78
def handle_defect(self, obj, defect):
"""Based on policy, either raise defect or call register_defect.
handle_defect(obj, defect)
defect should be a Defect subclass, but in any case must be an
Exception subclass. obj is the object on which the defect should be
registered if it is not raised. If the raise_on_defect is True, the
defect is raised as an error, otherwise the object and the defect are
passed to register_defect.
This method is intended to be called by parsers that discover defects.
The email package parsers always call it with Defect instances.
"""
if self.raise_on_defect:
raise defect
self.register_defect(obj, defect)
def register_defect(self, obj, defect):
"""Record 'defect' on 'obj'.
Called by handle_defect if raise_on_defect is False. This method is
part of the Policy API so that Policy subclasses can implement custom
defect handling. The default implementation calls the append method of
the defects attribute of obj. The objects used by the email package by
default that get passed to this method will always have a defects
attribute with an append method.
"""
obj.defects.append(defect)
def header_max_count(self, name):
"""Return the maximum allowed number of headers named 'name'.
Called when a header is added to a Message object. If the returned
value is not 0 or None, and there are already a number of headers with
the name 'name' equal to the value returned, a ValueError is raised.
Because the default behavior of Message's __setitem__ is to append the
value to the list of headers, it is easy to create duplicate headers
without realizing it. This method allows certain headers to be limited
in the number of instances of that header that may be added to a
Message programmatically. (The limit is not observed by the parser,
which will faithfully produce as many headers as exist in the message
being parsed.)
The default implementation returns None for all header names.
"""
return None
@abc.abstractmethod
def header_source_parse(self, sourcelines):
"""Given a list of linesep terminated strings constituting the lines of
a single header, return the (name, value) tuple that should be stored
in the model. The input lines should retain their terminating linesep
characters. The lines passed in by the email package may contain
surrogateescaped binary data.
"""
raise NotImplementedError
@abc.abstractmethod
def header_store_parse(self, name, value):
"""Given the header name and the value provided by the application
program, return the (name, value) that should be stored in the model.
"""
raise NotImplementedError
@abc.abstractmethod
def header_fetch_parse(self, name, value):
"""Given the header name and the value from the model, return the value
to be returned to the application program that is requesting that
header. The value passed in by the email package may contain
surrogateescaped binary data if the lines were parsed by a BytesParser.
The returned value should not contain any surrogateescaped data.
"""
raise NotImplementedError
@abc.abstractmethod
def fold(self, name, value):
"""Given the header name and the value from the model, return a string
containing linesep characters that implement the folding of the header
according to the policy controls. The value passed in by the email
package may contain surrogateescaped binary data if the lines were
parsed by a BytesParser. The returned value should not contain any
surrogateescaped data.
"""
raise NotImplementedError
@abc.abstractmethod
def fold_binary(self, name, value):
"""Given the header name and the value from the model, return binary
data containing linesep characters that implement the folding of the
header according to the policy controls. The value passed in by the
email package may contain surrogateescaped binary data.
"""
raise NotImplementedError
@_extend_docstrings
class Compat32(Policy):
"""+
This particular policy is the backward compatibility Policy. It
replicates the behavior of the email package version 5.1.
"""
def _sanitize_header(self, name, value):
# If the header value contains surrogates, return a Header using
# the unknown-8bit charset to encode the bytes as encoded words.
if not isinstance(value, str):
# Assume it is already a header object
return value
if _has_surrogates(value):
return header.Header(value, charset=_charset.UNKNOWN8BIT,
header_name=name)
else:
return value
def header_source_parse(self, sourcelines):
"""+
The name is parsed as everything up to the ':' and returned unmodified.
The value is determined by stripping leading whitespace off the
remainder of the first line, joining all subsequent lines together, and
stripping any trailing carriage return or linefeed characters.
"""
name, value = sourcelines[0].split(':', 1)
value = value.lstrip(' \t') + ''.join(sourcelines[1:])
return (name, value.rstrip('\r\n'))
def header_store_parse(self, name, value):
"""+
The name and value are returned unmodified.
"""
return (name, value)
def header_fetch_parse(self, name, value):
"""+
If the value contains binary data, it is converted into a Header object
using the unknown-8bit charset. Otherwise it is returned unmodified.
"""
return self._sanitize_header(name, value)
def fold(self, name, value):
"""+
Headers are folded using the Header folding algorithm, which preserves
existing line breaks in the value, and wraps each resulting line to the
max_line_length. Non-ASCII binary data are CTE encoded using the
unknown-8bit charset.
"""
return self._fold(name, value, sanitize=True)
def fold_binary(self, name, value):
"""+
Headers are folded using the Header folding algorithm, which preserves
existing line breaks in the value, and wraps each resulting line to the
max_line_length. If cte_type is 7bit, non-ascii binary data is CTE
encoded using the unknown-8bit charset. Otherwise the original source
header is used, with its existing line breaks and/or binary data.
"""
folded = self._fold(name, value, sanitize=self.cte_type=='7bit')
return folded.encode('ascii', 'surrogateescape')
def _fold(self, name, value, sanitize):
parts = []
parts.append('%s: ' % name)
if isinstance(value, str):
if _has_surrogates(value):
if sanitize:
h = header.Header(value,
charset=_charset.UNKNOWN8BIT,
header_name=name)
else:
# If we have raw 8bit data in a byte string, we have no idea
# what the encoding is. There is no safe way to split this
# string. If it's ascii-subset, then we could do a normal
# ascii split, but if it's multibyte then we could break the
# string. There's no way to know so the least harm seems to
# be to not split the string and risk it being too long.
parts.append(value)
h = None
else:
h = header.Header(value, header_name=name)
else:
# Assume it is a Header-like object.
h = value
if h is not None:
parts.append(h.encode(linesep=self.linesep,
maxlinelen=self.max_line_length))
parts.append(self.linesep)
return ''.join(parts)
compat32 = Compat32()
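A minimal sketch of how a policy is customized through clone() and addition (illustrative only):

from future.backports.email._policybase import compat32

strict = compat32.clone(raise_on_defect=True, linesep='\r\n')
merged = compat32 + strict            # non-default values of the right operand win
print(merged.raise_on_defect)         # True
print(compat32.raise_on_defect)       # False; policy instances are immutable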

View file

@@ -0,0 +1,120 @@
# Copyright (C) 2002-2007 Python Software Foundation
# Author: Ben Gertzfield
# Contact: email-sig@python.org
"""Base64 content transfer encoding per RFCs 2045-2047.
This module handles the content transfer encoding method defined in RFC 2045
to encode arbitrary 8-bit data using the three 8-bit bytes in four 7-bit
characters encoding known as Base64.
It is used in the MIME standards for email to attach images, audio, and text
using some 8-bit character sets to messages.
This module provides an interface to encode and decode both headers and bodies
with Base64 encoding.
RFC 2045 defines a method for including character set information in an
`encoded-word' in a header. This method is commonly used for 8-bit real names
in To:, From:, Cc:, etc. fields, as well as Subject: lines.
This module does not do the line wrapping or end-of-line character conversion
necessary for proper internationalized headers; it only does dumb encoding and
decoding. To deal with the various line wrapping issues, use the email.header
module.
"""
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from future.builtins import range
from future.builtins import bytes
__all__ = [
'body_decode',
'body_encode',
'decode',
'decodestring',
'header_encode',
'header_length',
]
from base64 import b64encode
from binascii import b2a_base64, a2b_base64
CRLF = '\r\n'
NL = '\n'
EMPTYSTRING = ''
# See also Charset.py
MISC_LEN = 7
# Helpers
def header_length(bytearray):
"""Return the length of s when it is encoded with base64."""
groups_of_3, leftover = divmod(len(bytearray), 3)
# 4 bytes out for each 3 bytes (or nonzero fraction thereof) in.
n = groups_of_3 * 4
if leftover:
n += 4
return n
def header_encode(header_bytes, charset='iso-8859-1'):
"""Encode a single header line with Base64 encoding in a given charset.
charset names the character set to use to encode the header. It defaults
to iso-8859-1. Base64 encoding is defined in RFC 2045.
"""
if not header_bytes:
return ""
if isinstance(header_bytes, str):
header_bytes = header_bytes.encode(charset)
encoded = b64encode(header_bytes).decode("ascii")
return '=?%s?b?%s?=' % (charset, encoded)
def body_encode(s, maxlinelen=76, eol=NL):
r"""Encode a string with base64.
Each line will be wrapped at, at most, maxlinelen characters (defaults to
76 characters).
Each line of encoded text will end with eol, which defaults to "\n". Set
this to "\r\n" if you will be using the result of this function directly
in an email.
"""
if not s:
return s
encvec = []
max_unencoded = maxlinelen * 3 // 4
for i in range(0, len(s), max_unencoded):
# BAW: should encode() inherit b2a_base64()'s dubious behavior in
# adding a newline to the encoded string?
enc = b2a_base64(s[i:i + max_unencoded]).decode("ascii")
if enc.endswith(NL) and eol != NL:
enc = enc[:-1] + eol
encvec.append(enc)
return EMPTYSTRING.join(encvec)
def decode(string):
"""Decode a raw base64 string, returning a bytes object.
This function does not parse a full MIME header value encoded with
base64 (like =?iso-8895-1?b?bmloISBuaWgh?=) -- please use the high
level email.header class for that functionality.
"""
if not string:
return bytes()
elif isinstance(string, str):
return a2b_base64(string.encode('raw-unicode-escape'))
else:
return a2b_base64(string)
# For convenience and backwards compatibility w/ standard base64 module
body_decode = decode
decodestring = decode
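Illustrative use of the helpers above (the header and body text are arbitrary):

from future.backports.email import base64mime

print(base64mime.header_encode(u'Z\xfcrich', charset='iso-8859-1'))
# e.g. '=?iso-8859-1?b?WvxyaWNo?='
print(base64mime.body_encode(b'hello world\n'))   # 'aGVsbG8gd29ybGQK\n'
print(base64mime.decode('aGVsbG8gd29ybGQK'))      # b'hello world\n'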

View file

@@ -0,0 +1,409 @@
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from future.builtins import str
from future.builtins import next
# Copyright (C) 2001-2007 Python Software Foundation
# Author: Ben Gertzfield, Barry Warsaw
# Contact: email-sig@python.org
__all__ = [
'Charset',
'add_alias',
'add_charset',
'add_codec',
]
from functools import partial
from future.backports import email
from future.backports.email import errors
from future.backports.email.encoders import encode_7or8bit
# Flags for types of header encodings
QP = 1 # Quoted-Printable
BASE64 = 2 # Base64
SHORTEST = 3 # the shorter of QP and base64, but only for headers
# In "=?charset?q?hello_world?=", the =?, ?q?, and ?= add up to 7
RFC2047_CHROME_LEN = 7
DEFAULT_CHARSET = 'us-ascii'
UNKNOWN8BIT = 'unknown-8bit'
EMPTYSTRING = ''
# Defaults
CHARSETS = {
# input header enc body enc output conv
'iso-8859-1': (QP, QP, None),
'iso-8859-2': (QP, QP, None),
'iso-8859-3': (QP, QP, None),
'iso-8859-4': (QP, QP, None),
# iso-8859-5 is Cyrillic, and not especially used
# iso-8859-6 is Arabic, also not particularly used
# iso-8859-7 is Greek, QP will not make it readable
# iso-8859-8 is Hebrew, QP will not make it readable
'iso-8859-9': (QP, QP, None),
'iso-8859-10': (QP, QP, None),
# iso-8859-11 is Thai, QP will not make it readable
'iso-8859-13': (QP, QP, None),
'iso-8859-14': (QP, QP, None),
'iso-8859-15': (QP, QP, None),
'iso-8859-16': (QP, QP, None),
'windows-1252':(QP, QP, None),
'viscii': (QP, QP, None),
'us-ascii': (None, None, None),
'big5': (BASE64, BASE64, None),
'gb2312': (BASE64, BASE64, None),
'euc-jp': (BASE64, None, 'iso-2022-jp'),
'shift_jis': (BASE64, None, 'iso-2022-jp'),
'iso-2022-jp': (BASE64, None, None),
'koi8-r': (BASE64, BASE64, None),
'utf-8': (SHORTEST, BASE64, 'utf-8'),
}
# Aliases for other commonly-used names for character sets. Map
# them to the real ones used in email.
ALIASES = {
'latin_1': 'iso-8859-1',
'latin-1': 'iso-8859-1',
'latin_2': 'iso-8859-2',
'latin-2': 'iso-8859-2',
'latin_3': 'iso-8859-3',
'latin-3': 'iso-8859-3',
'latin_4': 'iso-8859-4',
'latin-4': 'iso-8859-4',
'latin_5': 'iso-8859-9',
'latin-5': 'iso-8859-9',
'latin_6': 'iso-8859-10',
'latin-6': 'iso-8859-10',
'latin_7': 'iso-8859-13',
'latin-7': 'iso-8859-13',
'latin_8': 'iso-8859-14',
'latin-8': 'iso-8859-14',
'latin_9': 'iso-8859-15',
'latin-9': 'iso-8859-15',
'latin_10':'iso-8859-16',
'latin-10':'iso-8859-16',
'cp949': 'ks_c_5601-1987',
'euc_jp': 'euc-jp',
'euc_kr': 'euc-kr',
'ascii': 'us-ascii',
}
# Map charsets to their Unicode codec strings.
CODEC_MAP = {
'gb2312': 'eucgb2312_cn',
'big5': 'big5_tw',
# Hack: We don't want *any* conversion for stuff marked us-ascii, as all
# sorts of garbage might be sent to us in the guise of 7-bit us-ascii.
# Let that stuff pass through without conversion to/from Unicode.
'us-ascii': None,
}
# Convenience functions for extending the above mappings
def add_charset(charset, header_enc=None, body_enc=None, output_charset=None):
"""Add character set properties to the global registry.
charset is the input character set, and must be the canonical name of a
character set.
Optional header_enc and body_enc is either Charset.QP for
quoted-printable, Charset.BASE64 for base64 encoding, Charset.SHORTEST for
the shortest of qp or base64 encoding, or None for no encoding. SHORTEST
is only valid for header_enc. It describes how message headers and
message bodies in the input charset are to be encoded. Default is no
encoding.
Optional output_charset is the character set that the output should be
in. Conversions will proceed from input charset, to Unicode, to the
output charset when the method Charset.convert() is called. The default
is to output in the same character set as the input.
Both input_charset and output_charset must have Unicode codec entries in
the module's charset-to-codec mapping; use add_codec(charset, codecname)
to add codecs the module does not know about. See the codecs module's
documentation for more information.
"""
if body_enc == SHORTEST:
raise ValueError('SHORTEST not allowed for body_enc')
CHARSETS[charset] = (header_enc, body_enc, output_charset)
def add_alias(alias, canonical):
"""Add a character set alias.
alias is the alias name, e.g. latin-1
canonical is the character set's canonical name, e.g. iso-8859-1
"""
ALIASES[alias] = canonical
def add_codec(charset, codecname):
"""Add a codec that map characters in the given charset to/from Unicode.
charset is the canonical name of a character set. codecname is the name
of a Python codec, as appropriate for the second argument to the unicode()
built-in, or to the encode() method of a Unicode string.
"""
CODEC_MAP[charset] = codecname
# Convenience function for encoding strings, taking into account
# that they might be unknown-8bit (ie: have surrogate-escaped bytes)
def _encode(string, codec):
string = str(string)
if codec == UNKNOWN8BIT:
return string.encode('ascii', 'surrogateescape')
else:
return string.encode(codec)
class Charset(object):
"""Map character sets to their email properties.
This class provides information about the requirements imposed on email
for a specific character set. It also provides convenience routines for
converting between character sets, given the availability of the
applicable codecs. Given a character set, it will do its best to provide
information on how to use that character set in an email in an
RFC-compliant way.
Certain character sets must be encoded with quoted-printable or base64
when used in email headers or bodies. Certain character sets must be
converted outright, and are not allowed in email. Instances of this
module expose the following information about a character set:
input_charset: The initial character set specified. Common aliases
are converted to their `official' email names (e.g. latin_1
is converted to iso-8859-1). Defaults to 7-bit us-ascii.
header_encoding: If the character set must be encoded before it can be
used in an email header, this attribute will be set to
Charset.QP (for quoted-printable), Charset.BASE64 (for
base64 encoding), or Charset.SHORTEST for the shortest of
QP or BASE64 encoding. Otherwise, it will be None.
body_encoding: Same as header_encoding, but describes the encoding for the
mail message's body, which indeed may be different than the
header encoding. Charset.SHORTEST is not allowed for
body_encoding.
output_charset: Some character sets must be converted before they can be
used in email headers or bodies. If the input_charset is
one of them, this attribute will contain the name of the
charset output will be converted to. Otherwise, it will
be None.
input_codec: The name of the Python codec used to convert the
input_charset to Unicode. If no conversion codec is
necessary, this attribute will be None.
output_codec: The name of the Python codec used to convert Unicode
to the output_charset. If no conversion codec is necessary,
this attribute will have the same value as the input_codec.
"""
def __init__(self, input_charset=DEFAULT_CHARSET):
# RFC 2046, $4.1.2 says charsets are not case sensitive. We coerce to
# unicode because its .lower() is locale insensitive. If the argument
# is already a unicode, we leave it at that, but ensure that the
# charset is ASCII, as the standard (RFC XXX) requires.
try:
if isinstance(input_charset, str):
input_charset.encode('ascii')
else:
input_charset = str(input_charset, 'ascii')
except UnicodeError:
raise errors.CharsetError(input_charset)
input_charset = input_charset.lower()
# Set the input charset after filtering through the aliases
self.input_charset = ALIASES.get(input_charset, input_charset)
# We can try to guess which encoding and conversion to use by the
# charset_map dictionary. Try that first, but let the user override
# it.
henc, benc, conv = CHARSETS.get(self.input_charset,
(SHORTEST, BASE64, None))
if not conv:
conv = self.input_charset
# Set the attributes, allowing the arguments to override the default.
self.header_encoding = henc
self.body_encoding = benc
self.output_charset = ALIASES.get(conv, conv)
# Now set the codecs. If one isn't defined for input_charset,
# guess and try a Unicode codec with the same name as input_codec.
self.input_codec = CODEC_MAP.get(self.input_charset,
self.input_charset)
self.output_codec = CODEC_MAP.get(self.output_charset,
self.output_charset)
def __str__(self):
return self.input_charset.lower()
__repr__ = __str__
def __eq__(self, other):
return str(self) == str(other).lower()
def __ne__(self, other):
return not self.__eq__(other)
def get_body_encoding(self):
"""Return the content-transfer-encoding used for body encoding.
This is either the string `quoted-printable' or `base64' depending on
the encoding used, or it is a function in which case you should call
the function with a single argument, the Message object being
encoded. The function should then set the Content-Transfer-Encoding
header itself to whatever is appropriate.
Returns "quoted-printable" if self.body_encoding is QP.
Returns "base64" if self.body_encoding is BASE64.
Returns conversion function otherwise.
"""
assert self.body_encoding != SHORTEST
if self.body_encoding == QP:
return 'quoted-printable'
elif self.body_encoding == BASE64:
return 'base64'
else:
return encode_7or8bit
def get_output_charset(self):
"""Return the output character set.
This is self.output_charset if that is not None, otherwise it is
self.input_charset.
"""
return self.output_charset or self.input_charset
def header_encode(self, string):
"""Header-encode a string by converting it first to bytes.
The type of encoding (base64 or quoted-printable) will be based on
this charset's `header_encoding`.
:param string: A unicode string for the header. It must be possible
to encode this string to bytes using the character set's
output codec.
:return: The encoded string, with RFC 2047 chrome.
"""
codec = self.output_codec or 'us-ascii'
header_bytes = _encode(string, codec)
# 7bit/8bit encodings return the string unchanged (modulo conversions)
encoder_module = self._get_encoder(header_bytes)
if encoder_module is None:
return string
return encoder_module.header_encode(header_bytes, codec)
def header_encode_lines(self, string, maxlengths):
"""Header-encode a string by converting it first to bytes.
This is similar to `header_encode()` except that the string is fit
into maximum line lengths as given by the argument.
:param string: A unicode string for the header. It must be possible
to encode this string to bytes using the character set's
output codec.
:param maxlengths: Maximum line length iterator. Each element
returned from this iterator will provide the next maximum line
length. This parameter is used as an argument to built-in next()
and should never be exhausted. The maximum line lengths should
not count the RFC 2047 chrome. These line lengths are only a
hint; the splitter does the best it can.
:return: Lines of encoded strings, each with RFC 2047 chrome.
"""
# See which encoding we should use.
codec = self.output_codec or 'us-ascii'
header_bytes = _encode(string, codec)
encoder_module = self._get_encoder(header_bytes)
encoder = partial(encoder_module.header_encode, charset=codec)
# Calculate the number of characters that the RFC 2047 chrome will
# contribute to each line.
charset = self.get_output_charset()
extra = len(charset) + RFC2047_CHROME_LEN
# Now comes the hard part. We must encode bytes but we can't split on
# bytes because some character sets are variable length and each
# encoded word must stand on its own. So the problem is you have to
# encode to bytes to figure out this word's length, but you must split
# on characters. This causes two problems: first, we don't know how
# many octets a specific substring of unicode characters will get
# encoded to, and second, we don't know how many ASCII characters
# those octets will get encoded to. Unless we try it. Which seems
# inefficient. In the interest of being correct rather than fast (and
# in the hope that there will be few encoded headers in any such
# message), brute force it. :(
lines = []
current_line = []
maxlen = next(maxlengths) - extra
for character in string:
current_line.append(character)
this_line = EMPTYSTRING.join(current_line)
length = encoder_module.header_length(_encode(this_line, charset))
if length > maxlen:
# This last character doesn't fit so pop it off.
current_line.pop()
# Does nothing fit on the first line?
if not lines and not current_line:
lines.append(None)
else:
separator = (' ' if lines else '')
joined_line = EMPTYSTRING.join(current_line)
header_bytes = _encode(joined_line, codec)
lines.append(encoder(header_bytes))
current_line = [character]
maxlen = next(maxlengths) - extra
joined_line = EMPTYSTRING.join(current_line)
header_bytes = _encode(joined_line, codec)
lines.append(encoder(header_bytes))
return lines
def _get_encoder(self, header_bytes):
if self.header_encoding == BASE64:
return email.base64mime
elif self.header_encoding == QP:
return email.quoprimime
elif self.header_encoding == SHORTEST:
len64 = email.base64mime.header_length(header_bytes)
lenqp = email.quoprimime.header_length(header_bytes)
if len64 < lenqp:
return email.base64mime
else:
return email.quoprimime
else:
return None
def body_encode(self, string):
"""Body-encode a string by converting it first to bytes.
The type of encoding (base64 or quoted-printable) will be based on
self.body_encoding. If body_encoding is None, we assume the
output charset is a 7bit encoding, so re-encoding the decoded
string using the ascii codec produces the correct string version
of the content.
"""
if not string:
return string
if self.body_encoding is BASE64:
if isinstance(string, str):
string = string.encode(self.output_charset)
return email.base64mime.body_encode(string)
elif self.body_encoding is QP:
# quoprimime.body_encode takes a string, but operates on it as if
# it were a list of byte codes. For a (minimal) history on why
# this is so, see changeset 0cf700464177. To correctly encode a
# character set, then, we must turn it into pseudo bytes via the
# latin1 charset, which will encode any byte as a single code point
# between 0 and 255, which is what body_encode is expecting.
if isinstance(string, str):
string = string.encode(self.output_charset)
string = string.decode('latin1')
return email.quoprimime.body_encode(string)
else:
if isinstance(string, str):
string = string.encode(self.output_charset).decode('ascii')
return string
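A minimal sketch of the Charset API above (illustrative only; the explicit base64mime/quoprimime imports are an assumption, added so that _get_encoder can find those submodules on the package):

from future.backports.email import base64mime, quoprimime  # noqa: F401, looked up lazily by Charset
from future.backports.email.charset import QP, Charset

cs = Charset('latin-1')                # alias normalized to iso-8859-1
print(str(cs))                         # 'iso-8859-1'
print(cs.header_encoding == QP)        # True: headers use quoted-printable
print(cs.get_body_encoding())          # 'quoted-printable'
print(cs.header_encode(u'Z\xfcrich'))  # e.g. '=?iso-8859-1?q?Z=FCrich?='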

View file

@@ -0,0 +1,90 @@
# Copyright (C) 2001-2006 Python Software Foundation
# Author: Barry Warsaw
# Contact: email-sig@python.org
"""Encodings and related functions."""
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from future.builtins import str
__all__ = [
'encode_7or8bit',
'encode_base64',
'encode_noop',
'encode_quopri',
]
try:
from base64 import encodebytes as _bencode
except ImportError:
# Py2 compatibility. TODO: test this!
from base64 import encodestring as _bencode
from quopri import encodestring as _encodestring
def _qencode(s):
enc = _encodestring(s, quotetabs=True)
# Must encode spaces, which quopri.encodestring() doesn't do
return enc.replace(' ', '=20')
def encode_base64(msg):
"""Encode the message's payload in Base64.
Also, add an appropriate Content-Transfer-Encoding header.
"""
orig = msg.get_payload()
encdata = str(_bencode(orig), 'ascii')
msg.set_payload(encdata)
msg['Content-Transfer-Encoding'] = 'base64'
def encode_quopri(msg):
"""Encode the message's payload in quoted-printable.
Also, add an appropriate Content-Transfer-Encoding header.
"""
orig = msg.get_payload()
encdata = _qencode(orig)
msg.set_payload(encdata)
msg['Content-Transfer-Encoding'] = 'quoted-printable'
def encode_7or8bit(msg):
"""Set the Content-Transfer-Encoding header to 7bit or 8bit."""
orig = msg.get_payload()
if orig is None:
# There's no payload. For backwards compatibility we use 7bit
msg['Content-Transfer-Encoding'] = '7bit'
return
# We play a trick to make this go fast. If encoding/decoding to ASCII
# succeeds, we know the data must be 7bit, otherwise treat it as 8bit.
try:
if isinstance(orig, str):
orig.encode('ascii')
else:
orig.decode('ascii')
except UnicodeError:
charset = msg.get_charset()
output_cset = charset and charset.output_charset
# iso-2022-* is non-ASCII but encodes to a 7-bit representation
if output_cset and output_cset.lower().startswith('iso-2022-'):
msg['Content-Transfer-Encoding'] = '7bit'
else:
msg['Content-Transfer-Encoding'] = '8bit'
else:
msg['Content-Transfer-Encoding'] = '7bit'
if not isinstance(orig, str):
msg.set_payload(orig.decode('ascii', 'surrogateescape'))
def encode_noop(msg):
"""Do nothing."""
# Well, not quite *nothing*: in Python3 we have to turn bytes into a string
# in our internal surrogateescaped form in order to keep the model
# consistent.
orig = msg.get_payload()
if not isinstance(orig, str):
msg.set_payload(orig.decode('ascii', 'surrogateescape'))
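An illustrative sketch of one of these encoders applied to a Message (the payload bytes are arbitrary):

from future.backports.email import encoders
from future.backports.email.message import Message

msg = Message()
msg.set_payload(b'caf\xc3\xa9')
encoders.encode_base64(msg)
print(msg['Content-Transfer-Encoding'])   # 'base64'
print(msg.get_payload())                  # e.g. 'Y2Fmw6k=\n'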

View file

@@ -0,0 +1,111 @@
# Copyright (C) 2001-2006 Python Software Foundation
# Author: Barry Warsaw
# Contact: email-sig@python.org
"""email package exception classes."""
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from future.builtins import super
class MessageError(Exception):
"""Base class for errors in the email package."""
class MessageParseError(MessageError):
"""Base class for message parsing errors."""
class HeaderParseError(MessageParseError):
"""Error while parsing headers."""
class BoundaryError(MessageParseError):
"""Couldn't find terminating boundary."""
class MultipartConversionError(MessageError, TypeError):
"""Conversion to a multipart is prohibited."""
class CharsetError(MessageError):
"""An illegal charset was given."""
# These are parsing defects which the parser was able to work around.
class MessageDefect(ValueError):
"""Base class for a message defect."""
def __init__(self, line=None):
if line is not None:
super().__init__(line)
self.line = line
class NoBoundaryInMultipartDefect(MessageDefect):
"""A message claimed to be a multipart but had no boundary parameter."""
class StartBoundaryNotFoundDefect(MessageDefect):
"""The claimed start boundary was never found."""
class CloseBoundaryNotFoundDefect(MessageDefect):
"""A start boundary was found, but not the corresponding close boundary."""
class FirstHeaderLineIsContinuationDefect(MessageDefect):
"""A message had a continuation line as its first header line."""
class MisplacedEnvelopeHeaderDefect(MessageDefect):
"""A 'Unix-from' header was found in the middle of a header block."""
class MissingHeaderBodySeparatorDefect(MessageDefect):
"""Found line with no leading whitespace and no colon before blank line."""
# XXX: backward compatibility, just in case (it was never emitted).
MalformedHeaderDefect = MissingHeaderBodySeparatorDefect
class MultipartInvariantViolationDefect(MessageDefect):
"""A message claimed to be a multipart but no subparts were found."""
class InvalidMultipartContentTransferEncodingDefect(MessageDefect):
"""An invalid content transfer encoding was set on the multipart itself."""
class UndecodableBytesDefect(MessageDefect):
"""Header contained bytes that could not be decoded"""
class InvalidBase64PaddingDefect(MessageDefect):
"""base64 encoded sequence had an incorrect length"""
class InvalidBase64CharactersDefect(MessageDefect):
"""base64 encoded sequence had characters not in base64 alphabet"""
# These errors are specific to header parsing.
class HeaderDefect(MessageDefect):
"""Base class for a header defect."""
def __init__(self, *args, **kw):
super().__init__(*args, **kw)
class InvalidHeaderDefect(HeaderDefect):
"""Header is not valid, message gives details."""
class HeaderMissingRequiredValue(HeaderDefect):
"""A header that must have a value had none"""
class NonPrintableDefect(HeaderDefect):
"""ASCII characters outside the ascii-printable range found"""
def __init__(self, non_printables):
super().__init__(non_printables)
self.non_printables = non_printables
def __str__(self):
return ("the following ASCII non-printables found in header: "
"{}".format(self.non_printables))
class ObsoleteHeaderDefect(HeaderDefect):
"""Header uses syntax declared obsolete by RFC 5322"""
class NonASCIILocalPartDefect(HeaderDefect):
"""local_part contains non-ASCII characters"""
# This defect only occurs during unicode parsing, not when
# parsing messages decoded from binary.
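Two small illustrations of the exception hierarchy above (illustrative only):

from future.backports.email import errors

print(issubclass(errors.HeaderParseError, errors.MessageParseError))   # True
defect = errors.NonPrintableDefect(['\x01'])
print(str(defect))   # lists the non-printable characters found in the header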

View file

@@ -0,0 +1,525 @@
# Copyright (C) 2004-2006 Python Software Foundation
# Authors: Baxter, Wouters and Warsaw
# Contact: email-sig@python.org
"""FeedParser - An email feed parser.
The feed parser implements an interface for incrementally parsing an email
message, line by line. This has advantages for certain applications, such as
those reading email messages off a socket.
FeedParser.feed() is the primary interface for pushing new data into the
parser. It returns when there's nothing more it can do with the available
data. When you have no more data to push into the parser, call .close().
This completes the parsing and returns the root message object.
The other advantage of this parser is that it will never raise a parsing
exception. Instead, when it finds something unexpected, it adds a 'defect' to
the current message. Defects are just instances that live on the message
object's .defects attribute.
"""
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from future.builtins import object, range, super
from future.utils import implements_iterator, PY3
__all__ = ['FeedParser', 'BytesFeedParser']
import re
from future.backports.email import errors
from future.backports.email import message
from future.backports.email._policybase import compat32
NLCRE = re.compile('\r\n|\r|\n')
NLCRE_bol = re.compile('(\r\n|\r|\n)')
NLCRE_eol = re.compile(r'(\r\n|\r|\n)\Z')  # raw string: \Z is not a valid str escape
NLCRE_crack = re.compile('(\r\n|\r|\n)')
# RFC 2822 $3.6.8 Optional fields. ftext is %d33-57 / %d59-126, Any character
# except controls, SP, and ":".
headerRE = re.compile(r'^(From |[\041-\071\073-\176]{1,}:|[\t ])')
EMPTYSTRING = ''
NL = '\n'
NeedMoreData = object()
# @implements_iterator
class BufferedSubFile(object):
"""A file-ish object that can have new data loaded into it.
You can also push and pop line-matching predicates onto a stack. When the
current predicate matches the current line, a false EOF response
(i.e. empty string) is returned instead. This lets the parser adhere to a
simple abstraction -- it parses until EOF closes the current message.
"""
def __init__(self):
# The last partial line pushed into this object.
self._partial = ''
# The list of full, pushed lines, in reverse order
self._lines = []
# The stack of false-EOF checking predicates.
self._eofstack = []
# A flag indicating whether the file has been closed or not.
self._closed = False
def push_eof_matcher(self, pred):
self._eofstack.append(pred)
def pop_eof_matcher(self):
return self._eofstack.pop()
def close(self):
# Don't forget any trailing partial line.
self._lines.append(self._partial)
self._partial = ''
self._closed = True
def readline(self):
if not self._lines:
if self._closed:
return ''
return NeedMoreData
# Pop the line off the stack and see if it matches the current
# false-EOF predicate.
line = self._lines.pop()
# RFC 2046, section 5.1.2 requires us to recognize outer level
# boundaries at any level of inner nesting. Do this, but be sure it's
# in the order of most to least nested.
for ateof in self._eofstack[::-1]:
if ateof(line):
# We're at the false EOF. But push the last line back first.
self._lines.append(line)
return ''
return line
def unreadline(self, line):
# Let the consumer push a line back into the buffer.
assert line is not NeedMoreData
self._lines.append(line)
def push(self, data):
"""Push some new data into this object."""
# Handle any previous leftovers
data, self._partial = self._partial + data, ''
# Crack into lines, but preserve the newlines on the end of each
parts = NLCRE_crack.split(data)
# The *ahem* interesting behaviour of re.split when supplied grouping
# parentheses is that the last element of the resulting list is the
# data after the final RE. In the case of a NL/CR terminated string,
# this is the empty string.
self._partial = parts.pop()
#GAN 29Mar09 bugs 1555570, 1721862 Confusion at 8K boundary ending with \r:
# is there a \n to follow later?
if not self._partial and parts and parts[-1].endswith('\r'):
self._partial = parts.pop(-2)+parts.pop()
# parts is a list of strings, alternating between the line contents
# and the eol character(s). Gather up a list of lines after
# re-attaching the newlines.
lines = []
for i in range(len(parts) // 2):
lines.append(parts[i*2] + parts[i*2+1])
self.pushlines(lines)
def pushlines(self, lines):
# Reverse and insert at the front of the lines.
self._lines[:0] = lines[::-1]
def __iter__(self):
return self
def __next__(self):
line = self.readline()
if line == '':
raise StopIteration
return line
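# Editor's note (illustration only): the push()/readline() contract above
# can be seen in a short session. NeedMoreData is returned while the last
# line is still incomplete, and '' only after close() signals real EOF.
#
#     sf = BufferedSubFile()
#     sf.push('To: a@example.com\r\nSubj')  # second line still partial
#     sf.readline()        # -> 'To: a@example.com\r\n'
#     sf.readline()        # -> NeedMoreData (no complete line buffered)
#     sf.push('ect: hi\r\n')
#     sf.readline()        # -> 'Subject: hi\r\n'
#     sf.close()
#     sf.readline()        # -> '' (true EOF)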
class FeedParser(object):
"""A feed-style parser of email."""
def __init__(self, _factory=message.Message, **_3to2kwargs):
if 'policy' in _3to2kwargs: policy = _3to2kwargs['policy']; del _3to2kwargs['policy']
else: policy = compat32
"""_factory is called with no arguments to create a new message obj
The policy keyword specifies a policy object that controls a number of
aspects of the parser's operation. The default policy maintains
backward compatibility.
"""
self._factory = _factory
self.policy = policy
try:
_factory(policy=self.policy)
self._factory_kwds = lambda: {'policy': self.policy}
except TypeError:
# Assume this is an old-style factory
self._factory_kwds = lambda: {}
self._input = BufferedSubFile()
self._msgstack = []
if PY3:
self._parse = self._parsegen().__next__
else:
self._parse = self._parsegen().next
self._cur = None
self._last = None
self._headersonly = False
# Non-public interface for supporting Parser's headersonly flag
def _set_headersonly(self):
self._headersonly = True
def feed(self, data):
"""Push more data into the parser."""
self._input.push(data)
self._call_parse()
def _call_parse(self):
try:
self._parse()
except StopIteration:
pass
def close(self):
"""Parse all remaining data and return the root message object."""
self._input.close()
self._call_parse()
root = self._pop_message()
assert not self._msgstack
# Look for final set of defects
if root.get_content_maintype() == 'multipart' \
and not root.is_multipart():
defect = errors.MultipartInvariantViolationDefect()
self.policy.handle_defect(root, defect)
return root
def _new_message(self):
msg = self._factory(**self._factory_kwds())
if self._cur and self._cur.get_content_type() == 'multipart/digest':
msg.set_default_type('message/rfc822')
if self._msgstack:
self._msgstack[-1].attach(msg)
self._msgstack.append(msg)
self._cur = msg
self._last = msg
def _pop_message(self):
retval = self._msgstack.pop()
if self._msgstack:
self._cur = self._msgstack[-1]
else:
self._cur = None
return retval
def _parsegen(self):
# Create a new message and start by parsing headers.
self._new_message()
headers = []
# Collect the headers, searching for a line that doesn't match the RFC
# 2822 header or continuation pattern (including an empty line).
for line in self._input:
if line is NeedMoreData:
yield NeedMoreData
continue
if not headerRE.match(line):
# If we saw the RFC defined header/body separator
# (i.e. newline), just throw it away. Otherwise the line is
# part of the body so push it back.
if not NLCRE.match(line):
defect = errors.MissingHeaderBodySeparatorDefect()
self.policy.handle_defect(self._cur, defect)
self._input.unreadline(line)
break
headers.append(line)
# Done with the headers, so parse them and figure out what we're
# supposed to see in the body of the message.
self._parse_headers(headers)
# Headers-only parsing is a backwards compatibility hack, which was
# necessary in the older parser, which could raise errors. All
# remaining lines in the input are thrown into the message body.
if self._headersonly:
lines = []
while True:
line = self._input.readline()
if line is NeedMoreData:
yield NeedMoreData
continue
if line == '':
break
lines.append(line)
self._cur.set_payload(EMPTYSTRING.join(lines))
return
if self._cur.get_content_type() == 'message/delivery-status':
# message/delivery-status contains blocks of headers separated by
# a blank line. We'll represent each header block as a separate
# nested message object, but the processing is a bit different
# than standard message/* types because there is no body for the
# nested messages. A blank line separates the subparts.
while True:
self._input.push_eof_matcher(NLCRE.match)
for retval in self._parsegen():
if retval is NeedMoreData:
yield NeedMoreData
continue
break
msg = self._pop_message()
# We need to pop the EOF matcher in order to tell if we're at
# the end of the current file, not the end of the last block
# of message headers.
self._input.pop_eof_matcher()
# The input stream must be sitting at the newline or at the
# EOF. We want to see if we're at the end of this subpart, so
# first consume the blank line, then test the next line to see
# if we're at this subpart's EOF.
while True:
line = self._input.readline()
if line is NeedMoreData:
yield NeedMoreData
continue
break
while True:
line = self._input.readline()
if line is NeedMoreData:
yield NeedMoreData
continue
break
if line == '':
break
# Not at EOF so this is a line we're going to need.
self._input.unreadline(line)
return
if self._cur.get_content_maintype() == 'message':
# The message claims to be a message/* type; what follows is
# another RFC 2822 message.
for retval in self._parsegen():
if retval is NeedMoreData:
yield NeedMoreData
continue
break
self._pop_message()
return
if self._cur.get_content_maintype() == 'multipart':
boundary = self._cur.get_boundary()
if boundary is None:
# The message /claims/ to be a multipart but it has not
# defined a boundary. That's a problem which we'll handle by
# reading everything until the EOF and marking the message as
# defective.
defect = errors.NoBoundaryInMultipartDefect()
self.policy.handle_defect(self._cur, defect)
lines = []
for line in self._input:
if line is NeedMoreData:
yield NeedMoreData
continue
lines.append(line)
self._cur.set_payload(EMPTYSTRING.join(lines))
return
# Make sure a valid content type was specified per RFC 2045:6.4.
if (self._cur.get('content-transfer-encoding', '8bit').lower()
not in ('7bit', '8bit', 'binary')):
defect = errors.InvalidMultipartContentTransferEncodingDefect()
self.policy.handle_defect(self._cur, defect)
# Create a line match predicate which matches the inter-part
# boundary as well as the end-of-multipart boundary. Don't push
# this onto the input stream until we've scanned past the
# preamble.
separator = '--' + boundary
boundaryre = re.compile(
'(?P<sep>' + re.escape(separator) +
r')(?P<end>--)?(?P<ws>[ \t]*)(?P<linesep>\r\n|\r|\n)?$')
capturing_preamble = True
preamble = []
linesep = False
close_boundary_seen = False
while True:
line = self._input.readline()
if line is NeedMoreData:
yield NeedMoreData
continue
if line == '':
break
mo = boundaryre.match(line)
if mo:
# If we're looking at the end boundary, we're done with
# this multipart. If there was a newline at the end of
# the closing boundary, then we need to initialize the
# epilogue with the empty string (see below).
if mo.group('end'):
close_boundary_seen = True
linesep = mo.group('linesep')
break
# We saw an inter-part boundary. Were we in the preamble?
if capturing_preamble:
if preamble:
# According to RFC 2046, the last newline belongs
# to the boundary.
lastline = preamble[-1]
eolmo = NLCRE_eol.search(lastline)
if eolmo:
preamble[-1] = lastline[:-len(eolmo.group(0))]
self._cur.preamble = EMPTYSTRING.join(preamble)
capturing_preamble = False
self._input.unreadline(line)
continue
# We saw a boundary separating two parts. Consume any
# multiple boundary lines that may be following. Our
# interpretation of RFC 2046 BNF grammar does not produce
# body parts within such double boundaries.
while True:
line = self._input.readline()
if line is NeedMoreData:
yield NeedMoreData
continue
mo = boundaryre.match(line)
if not mo:
self._input.unreadline(line)
break
# Recurse to parse this subpart; the input stream points
# at the subpart's first line.
self._input.push_eof_matcher(boundaryre.match)
for retval in self._parsegen():
if retval is NeedMoreData:
yield NeedMoreData
continue
break
# Because of RFC 2046, the newline preceding the boundary
# separator actually belongs to the boundary, not the
# previous subpart's payload (or epilogue if the previous
# part is a multipart).
if self._last.get_content_maintype() == 'multipart':
epilogue = self._last.epilogue
if epilogue == '':
self._last.epilogue = None
elif epilogue is not None:
mo = NLCRE_eol.search(epilogue)
if mo:
end = len(mo.group(0))
self._last.epilogue = epilogue[:-end]
else:
payload = self._last._payload
if isinstance(payload, str):
mo = NLCRE_eol.search(payload)
if mo:
payload = payload[:-len(mo.group(0))]
self._last._payload = payload
self._input.pop_eof_matcher()
self._pop_message()
# Set the multipart up for newline cleansing, which will
# happen if we're in a nested multipart.
self._last = self._cur
else:
# I think we must be in the preamble
assert capturing_preamble
preamble.append(line)
# We've seen either the EOF or the end boundary. If we're still
# capturing the preamble, we never saw the start boundary. Note
# that as a defect and store the captured text as the payload.
if capturing_preamble:
defect = errors.StartBoundaryNotFoundDefect()
self.policy.handle_defect(self._cur, defect)
self._cur.set_payload(EMPTYSTRING.join(preamble))
epilogue = []
for line in self._input:
if line is NeedMoreData:
yield NeedMoreData
continue
self._cur.epilogue = EMPTYSTRING.join(epilogue)
return
# If we're not processing the preamble, then we might have seen
# EOF without seeing that end boundary...that is also a defect.
if not close_boundary_seen:
defect = errors.CloseBoundaryNotFoundDefect()
self.policy.handle_defect(self._cur, defect)
return
# Everything from here to the EOF is epilogue. If the end boundary
# ended in a newline, we'll need to make sure the epilogue isn't
# None
if linesep:
epilogue = ['']
else:
epilogue = []
for line in self._input:
if line is NeedMoreData:
yield NeedMoreData
continue
epilogue.append(line)
# Any CRLF at the front of the epilogue is not technically part of
# the epilogue. Also, watch out for an empty string epilogue,
# which means a single newline.
if epilogue:
firstline = epilogue[0]
bolmo = NLCRE_bol.match(firstline)
if bolmo:
epilogue[0] = firstline[len(bolmo.group(0)):]
self._cur.epilogue = EMPTYSTRING.join(epilogue)
return
# Otherwise, it's some non-multipart type, so the entire rest of the
# file contents becomes the payload.
lines = []
for line in self._input:
if line is NeedMoreData:
yield NeedMoreData
continue
lines.append(line)
self._cur.set_payload(EMPTYSTRING.join(lines))
def _parse_headers(self, lines):
# Passed a list of lines that make up the headers for the current msg
lastheader = ''
lastvalue = []
for lineno, line in enumerate(lines):
# Check for continuation
if line[0] in ' \t':
if not lastheader:
# The first line of the headers was a continuation. This
# is illegal, so let's note the defect, store the illegal
# line, and ignore it for purposes of headers.
defect = errors.FirstHeaderLineIsContinuationDefect(line)
self.policy.handle_defect(self._cur, defect)
continue
lastvalue.append(line)
continue
if lastheader:
self._cur.set_raw(*self.policy.header_source_parse(lastvalue))
lastheader, lastvalue = '', []
# Check for envelope header, i.e. unix-from
if line.startswith('From '):
if lineno == 0:
# Strip off the trailing newline
mo = NLCRE_eol.search(line)
if mo:
line = line[:-len(mo.group(0))]
self._cur.set_unixfrom(line)
continue
elif lineno == len(lines) - 1:
# Something looking like a unix-from at the end - it's
# probably the first line of the body, so push back the
# line and stop.
self._input.unreadline(line)
return
else:
# Weirdly placed unix-from line. Note this as a defect
# and ignore it.
defect = errors.MisplacedEnvelopeHeaderDefect(line)
self._cur.defects.append(defect)
continue
# Split the line on the colon separating field name from value.
# There will always be a colon, because if there wasn't the part of
# the parser that calls us would have started parsing the body.
i = line.find(':')
assert i>0, "_parse_headers fed line with no : and no leading WS"
lastheader = line[:i]
lastvalue = [line]
# Done with all the lines, so handle the last header.
if lastheader:
self._cur.set_raw(*self.policy.header_source_parse(lastvalue))
class BytesFeedParser(FeedParser):
"""Like FeedParser, but feed accepts bytes."""
def feed(self, data):
super().feed(data.decode('ascii', 'surrogateescape'))
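# --- Usage sketch (editor's illustration, not part of this backport) ---
# Incremental parsing: feed() accepts arbitrary chunks (they need not align
# with line boundaries) and close() returns the root Message object.
if __name__ == '__main__':
    parser = FeedParser()
    parser.feed('Subject: hello\r\n')
    parser.feed('\r\nbody line 1\r\n')
    msg = parser.close()
    print(msg['subject'])              # hello
    print(repr(msg.get_payload()))     # 'body line 1\r\n'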

View file

@@ -0,0 +1,498 @@
# Copyright (C) 2001-2010 Python Software Foundation
# Author: Barry Warsaw
# Contact: email-sig@python.org
"""Classes to generate plain text from a message object tree."""
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from future.builtins import super
from future.builtins import str
__all__ = ['Generator', 'DecodedGenerator', 'BytesGenerator']
import re
import sys
import time
import random
import warnings
from io import StringIO, BytesIO
from future.backports.email._policybase import compat32
from future.backports.email.header import Header
from future.backports.email.utils import _has_surrogates
import future.backports.email.charset as _charset
UNDERSCORE = '_'
NL = '\n' # XXX: no longer used by the code below.
fcre = re.compile(r'^From ', re.MULTILINE)
class Generator(object):
"""Generates output from a Message object tree.
This basic generator writes the message to the given file object as plain
text.
"""
#
# Public interface
#
def __init__(self, outfp, mangle_from_=True, maxheaderlen=None, **_3to2kwargs):
if 'policy' in _3to2kwargs: policy = _3to2kwargs['policy']; del _3to2kwargs['policy']
else: policy = None
"""Create the generator for message flattening.
outfp is the output file-like object for writing the message to. It
must have a write() method.
Optional mangle_from_ is a flag that, when True (the default), escapes
From_ lines in the body of the message by putting a `>' in front of
them.
Optional maxheaderlen specifies the longest length for a non-continued
header. When a header line is longer (in characters, with tabs
expanded to 8 spaces) than maxheaderlen, the header will split as
defined in the Header class. Set maxheaderlen to zero to disable
header wrapping. The default is 78, as recommended (but not required)
by RFC 2822.
The policy keyword specifies a policy object that controls a number of
aspects of the generator's operation. The default policy maintains
backward compatibility.
"""
self._fp = outfp
self._mangle_from_ = mangle_from_
self.maxheaderlen = maxheaderlen
self.policy = policy
def write(self, s):
# Just delegate to the file object
self._fp.write(s)
def flatten(self, msg, unixfrom=False, linesep=None):
r"""Print the message object tree rooted at msg to the output file
specified when the Generator instance was created.
unixfrom is a flag that forces the printing of a Unix From_ delimiter
before the first object in the message tree. If the original message
has no From_ delimiter, a `standard' one is crafted. By default, this
is False to inhibit the printing of any From_ delimiter.
Note that for subobjects, no From_ line is printed.
linesep specifies the characters used to indicate a new line in
the output. The default value is determined by the policy.
"""
# We use the _XXX constants for operating on data that comes directly
# from the msg, and _encoded_XXX constants for operating on data that
# has already been converted (to bytes in the BytesGenerator) and
# inserted into a temporary buffer.
policy = msg.policy if self.policy is None else self.policy
if linesep is not None:
policy = policy.clone(linesep=linesep)
if self.maxheaderlen is not None:
policy = policy.clone(max_line_length=self.maxheaderlen)
self._NL = policy.linesep
self._encoded_NL = self._encode(self._NL)
self._EMPTY = ''
self._encoded_EMPTY = self._encode('')
# Because we use clone (below) when we recursively process message
# subparts, and because clone uses the computed policy (not None),
# submessages will automatically get set to the computed policy when
# they are processed by this code.
old_gen_policy = self.policy
old_msg_policy = msg.policy
try:
self.policy = policy
msg.policy = policy
if unixfrom:
ufrom = msg.get_unixfrom()
if not ufrom:
ufrom = 'From nobody ' + time.ctime(time.time())
self.write(ufrom + self._NL)
self._write(msg)
finally:
self.policy = old_gen_policy
msg.policy = old_msg_policy
def clone(self, fp):
"""Clone this generator with the exact same options."""
return self.__class__(fp,
self._mangle_from_,
None, # Use policy setting, which we've adjusted
policy=self.policy)
#
# Protected interface - undocumented ;/
#
# Note that we use 'self.write' when what we are writing is coming from
# the source, and self._fp.write when what we are writing is coming from a
# buffer (because the Bytes subclass has already had a chance to transform
# the data in its write method in that case). This is an entirely
# pragmatic split determined by experiment; we could be more general by
# always using write and having the Bytes subclass write method detect when
# it has already transformed the input; but, since this whole thing is a
# hack anyway this seems good enough.
# Similarly, we have _XXX and _encoded_XXX attributes that are used on
# source and buffer data, respectively.
_encoded_EMPTY = ''
def _new_buffer(self):
# BytesGenerator overrides this to return BytesIO.
return StringIO()
def _encode(self, s):
# BytesGenerator overrides this to encode strings to bytes.
return s
def _write_lines(self, lines):
# We have to transform the line endings.
if not lines:
return
lines = lines.splitlines(True)
for line in lines[:-1]:
self.write(line.rstrip('\r\n'))
self.write(self._NL)
laststripped = lines[-1].rstrip('\r\n')
self.write(laststripped)
if len(lines[-1]) != len(laststripped):
self.write(self._NL)
def _write(self, msg):
# We can't write the headers yet because of the following scenario:
# say a multipart message includes the boundary string somewhere in
# its body. We'd have to calculate the new boundary /before/ we write
# the headers so that we can write the correct Content-Type:
# parameter.
#
# The way we do this, so as to make the _handle_*() methods simpler,
# is to cache any subpart writes into a buffer. Then we write the
# headers and the buffer contents. That way, subpart handlers can
# Do The Right Thing, and can still modify the Content-Type: header if
# necessary.
oldfp = self._fp
try:
self._fp = sfp = self._new_buffer()
self._dispatch(msg)
finally:
self._fp = oldfp
# Write the headers. First we see if the message object wants to
# handle that itself. If not, we'll do it generically.
meth = getattr(msg, '_write_headers', None)
if meth is None:
self._write_headers(msg)
else:
meth(self)
self._fp.write(sfp.getvalue())
def _dispatch(self, msg):
# Get the Content-Type: for the message, then try to dispatch to
# self._handle_<maintype>_<subtype>(). If there's no handler for the
# full MIME type, then dispatch to self._handle_<maintype>(). If
# that's missing too, then dispatch to self._writeBody().
main = msg.get_content_maintype()
sub = msg.get_content_subtype()
specific = UNDERSCORE.join((main, sub)).replace('-', '_')
meth = getattr(self, '_handle_' + specific, None)
if meth is None:
generic = main.replace('-', '_')
meth = getattr(self, '_handle_' + generic, None)
if meth is None:
meth = self._writeBody
meth(msg)
#
# Default handlers
#
def _write_headers(self, msg):
for h, v in msg.raw_items():
self.write(self.policy.fold(h, v))
# A blank line always separates headers from body
self.write(self._NL)
#
# Handlers for writing types and subtypes
#
def _handle_text(self, msg):
payload = msg.get_payload()
if payload is None:
return
if not isinstance(payload, str):
raise TypeError('string payload expected: %s' % type(payload))
if _has_surrogates(msg._payload):
charset = msg.get_param('charset')
if charset is not None:
del msg['content-transfer-encoding']
msg.set_payload(payload, charset)
payload = msg.get_payload()
if self._mangle_from_:
payload = fcre.sub('>From ', payload)
self._write_lines(payload)
# Default body handler
_writeBody = _handle_text
def _handle_multipart(self, msg):
# The trick here is to write out each part separately, merge them all
# together, and then make sure that the boundary we've chosen isn't
# present in the payload.
msgtexts = []
subparts = msg.get_payload()
if subparts is None:
subparts = []
elif isinstance(subparts, str):
# e.g. a non-strict parse of a message with no starting boundary.
self.write(subparts)
return
elif not isinstance(subparts, list):
# Scalar payload
subparts = [subparts]
for part in subparts:
s = self._new_buffer()
g = self.clone(s)
g.flatten(part, unixfrom=False, linesep=self._NL)
msgtexts.append(s.getvalue())
# BAW: What about boundaries that are wrapped in double-quotes?
boundary = msg.get_boundary()
if not boundary:
# Create a boundary that doesn't appear in any of the
# message texts.
alltext = self._encoded_NL.join(msgtexts)
boundary = self._make_boundary(alltext)
msg.set_boundary(boundary)
# If there's a preamble, write it out, with a trailing CRLF
if msg.preamble is not None:
if self._mangle_from_:
preamble = fcre.sub('>From ', msg.preamble)
else:
preamble = msg.preamble
self._write_lines(preamble)
self.write(self._NL)
# dash-boundary transport-padding CRLF
self.write('--' + boundary + self._NL)
# body-part
if msgtexts:
self._fp.write(msgtexts.pop(0))
# *encapsulation
# --> delimiter transport-padding
# --> CRLF body-part
for body_part in msgtexts:
# delimiter transport-padding CRLF
self.write(self._NL + '--' + boundary + self._NL)
# body-part
self._fp.write(body_part)
# close-delimiter transport-padding
self.write(self._NL + '--' + boundary + '--')
if msg.epilogue is not None:
self.write(self._NL)
if self._mangle_from_:
epilogue = fcre.sub('>From ', msg.epilogue)
else:
epilogue = msg.epilogue
self._write_lines(epilogue)
def _handle_multipart_signed(self, msg):
# The contents of signed parts has to stay unmodified in order to keep
# the signature intact per RFC1847 2.1, so we disable header wrapping.
# RDM: This isn't enough to completely preserve the part, but it helps.
p = self.policy
self.policy = p.clone(max_line_length=0)
try:
self._handle_multipart(msg)
finally:
self.policy = p
def _handle_message_delivery_status(self, msg):
# We can't just write the headers directly to self's file object
# because this will leave an extra newline between the last header
# block and the boundary. Sigh.
blocks = []
for part in msg.get_payload():
s = self._new_buffer()
g = self.clone(s)
g.flatten(part, unixfrom=False, linesep=self._NL)
text = s.getvalue()
lines = text.split(self._encoded_NL)
# Strip off the unnecessary trailing empty line
if lines and lines[-1] == self._encoded_EMPTY:
blocks.append(self._encoded_NL.join(lines[:-1]))
else:
blocks.append(text)
# Now join all the blocks with an empty line. This has the lovely
# effect of separating each block with an empty line, but not adding
# an extra one after the last one.
self._fp.write(self._encoded_NL.join(blocks))
def _handle_message(self, msg):
s = self._new_buffer()
g = self.clone(s)
# The payload of a message/rfc822 part should be a multipart sequence
# of length 1. The zeroth element of the list should be the Message
# object for the subpart. Extract that object, stringify it, and
# write it out.
# Except, it turns out, when it's a string instead, which happens when
# and only when HeaderParser is used on a message of mime type
# message/rfc822. Such messages are generated by, for example,
# Groupwise when forwarding unadorned messages. (Issue 7970.) So
# in that case we just emit the string body.
payload = msg._payload
if isinstance(payload, list):
g.flatten(msg.get_payload(0), unixfrom=False, linesep=self._NL)
payload = s.getvalue()
else:
payload = self._encode(payload)
self._fp.write(payload)
# This used to be a module level function; we use a classmethod for this
# and _compile_re so we can continue to provide the module level function
# for backward compatibility by doing
# _make_boundary = Generator._make_boundary
# at the end of the module. It *is* internal, so we could drop that...
@classmethod
def _make_boundary(cls, text=None):
# Craft a random boundary. If text is given, ensure that the chosen
# boundary doesn't appear in the text.
token = random.randrange(sys.maxsize)
boundary = ('=' * 15) + (_fmt % token) + '=='
if text is None:
return boundary
b = boundary
counter = 0
while True:
cre = cls._compile_re('^--' + re.escape(b) + '(--)?$', re.MULTILINE)
if not cre.search(text):
break
b = boundary + '.' + str(counter)
counter += 1
return b
@classmethod
def _compile_re(cls, s, flags):
return re.compile(s, flags)
class BytesGenerator(Generator):
"""Generates a bytes version of a Message object tree.
Functionally identical to the base Generator except that the output is
bytes and not string. When surrogates were used in the input to encode
bytes, these are decoded back to bytes for output. If the policy has
cte_type set to 7bit, then the message is transformed such that the
non-ASCII bytes are properly content transfer encoded, using the charset
unknown-8bit.
The outfp object must accept bytes in its write method.
"""
# Bytes versions of this constant for use in manipulating data from
# the BytesIO buffer.
_encoded_EMPTY = b''
def write(self, s):
self._fp.write(str(s).encode('ascii', 'surrogateescape'))
def _new_buffer(self):
return BytesIO()
def _encode(self, s):
return s.encode('ascii')
def _write_headers(self, msg):
# This is almost the same as the string version, except for handling
# strings with 8bit bytes.
for h, v in msg.raw_items():
self._fp.write(self.policy.fold_binary(h, v))
# A blank line always separates headers from body
self.write(self._NL)
def _handle_text(self, msg):
# If the string has surrogates the original source was bytes, so
# just write it back out.
if msg._payload is None:
return
if _has_surrogates(msg._payload) and not self.policy.cte_type=='7bit':
if self._mangle_from_:
msg._payload = fcre.sub(">From ", msg._payload)
self._write_lines(msg._payload)
else:
super(BytesGenerator,self)._handle_text(msg)
# Default body handler
_writeBody = _handle_text
@classmethod
def _compile_re(cls, s, flags):
return re.compile(s.encode('ascii'), flags)
_FMT = '[Non-text (%(type)s) part of message omitted, filename %(filename)s]'
class DecodedGenerator(Generator):
"""Generates a text representation of a message.
Like the Generator base class, except that non-text parts are substituted
with a format string representing the part.
"""
def __init__(self, outfp, mangle_from_=True, maxheaderlen=78, fmt=None):
"""Like Generator.__init__() except that an additional optional
argument is allowed.
Walks through all subparts of a message. If the subpart is of main
type `text', then it prints the decoded payload of the subpart.
Otherwise, fmt is a format string that is used instead of the message
payload. fmt is expanded with the following keywords (in
%(keyword)s format):
type : Full MIME type of the non-text part
maintype : Main MIME type of the non-text part
subtype : Sub-MIME type of the non-text part
filename : Filename of the non-text part
description: Description associated with the non-text part
encoding : Content transfer encoding of the non-text part
The default value for fmt is None, meaning
[Non-text (%(type)s) part of message omitted, filename %(filename)s]
"""
Generator.__init__(self, outfp, mangle_from_, maxheaderlen)
if fmt is None:
self._fmt = _FMT
else:
self._fmt = fmt
def _dispatch(self, msg):
for part in msg.walk():
maintype = part.get_content_maintype()
if maintype == 'text':
print(part.get_payload(decode=False), file=self)
elif maintype == 'multipart':
# Just skip this
pass
else:
print(self._fmt % {
'type' : part.get_content_type(),
'maintype' : part.get_content_maintype(),
'subtype' : part.get_content_subtype(),
'filename' : part.get_filename('[no filename]'),
'description': part.get('Content-Description',
'[no description]'),
'encoding' : part.get('Content-Transfer-Encoding',
'[no encoding]'),
}, file=self)
# Helper used by Generator._make_boundary
_width = len(repr(sys.maxsize-1))
_fmt = '%%0%dd' % _width
# Backward compatibility
_make_boundary = Generator._make_boundary
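# --- Usage sketch (editor's illustration, not part of this backport) ---
# Generator.flatten() is the inverse of parsing: it serializes a Message
# tree to a file-like object, optionally escaping body "From " lines.
if __name__ == '__main__':
    from future.backports.email.message import Message
    msg = Message()
    msg['Subject'] = 'demo'
    msg.set_payload('From here on, plain text.\n')
    out = StringIO()
    Generator(out, mangle_from_=True).flatten(msg)
    print(out.getvalue())
    # Subject: demo
    #
    # >From here on, plain text.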

View file

@@ -0,0 +1,581 @@
# Copyright (C) 2002-2007 Python Software Foundation
# Author: Ben Gertzfield, Barry Warsaw
# Contact: email-sig@python.org
"""Header encoding and decoding functionality."""
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from future.builtins import bytes, range, str, super, zip
__all__ = [
'Header',
'decode_header',
'make_header',
]
import re
import binascii
from future.backports import email
from future.backports.email import base64mime
from future.backports.email.errors import HeaderParseError
import future.backports.email.charset as _charset
# Helpers
from future.backports.email.quoprimime import _max_append, header_decode
Charset = _charset.Charset
NL = '\n'
SPACE = ' '
BSPACE = b' '
SPACE8 = ' ' * 8
EMPTYSTRING = ''
MAXLINELEN = 78
FWS = ' \t'
USASCII = Charset('us-ascii')
UTF8 = Charset('utf-8')
# Match encoded-word strings in the form =?charset?q?Hello_World?=
ecre = re.compile(r'''
=\? # literal =?
(?P<charset>[^?]*?) # non-greedy up to the next ? is the charset
\? # literal ?
(?P<encoding>[qb]) # either a "q" or a "b", case insensitive
\? # literal ?
(?P<encoded>.*?) # non-greedy up to the next ?= is the encoded string
\?= # literal ?=
''', re.VERBOSE | re.IGNORECASE | re.MULTILINE)
# Field name regexp, including trailing colon, but not separating whitespace,
# according to RFC 2822. Character range is from tilde to exclamation mark.
# For use with .match()
fcre = re.compile(r'[\041-\176]+:$')
# Find a header embedded in a putative header value. Used to check for
# header injection attack.
_embeded_header = re.compile(r'\n[^ \t]+:')
def decode_header(header):
"""Decode a message header value without converting charset.
Returns a list of (string, charset) pairs containing each of the decoded
parts of the header. Charset is None for non-encoded parts of the header,
otherwise a lower-case string containing the name of the character set
specified in the encoded string.
header may be a string that may or may not contain RFC2047 encoded words,
or it may be a Header object.
An email.errors.HeaderParseError may be raised when certain decoding errors
occur (e.g. a base64 decoding exception).
"""
# If it is a Header object, we can just return the encoded chunks.
if hasattr(header, '_chunks'):
return [(_charset._encode(string, str(charset)), str(charset))
for string, charset in header._chunks]
# If no encoding, just return the header with no charset.
if not ecre.search(header):
return [(header, None)]
# First step is to parse all the encoded parts into triplets of the form
# (encoded_string, encoding, charset). For unencoded strings, the last
# two parts will be None.
words = []
for line in header.splitlines():
parts = ecre.split(line)
first = True
while parts:
unencoded = parts.pop(0)
if first:
unencoded = unencoded.lstrip()
first = False
if unencoded:
words.append((unencoded, None, None))
if parts:
charset = parts.pop(0).lower()
encoding = parts.pop(0).lower()
encoded = parts.pop(0)
words.append((encoded, encoding, charset))
# Now loop over words and remove words that consist of whitespace
# between two encoded strings.
import sys
droplist = []
for n, w in enumerate(words):
if n>1 and w[1] and words[n-2][1] and words[n-1][0].isspace():
droplist.append(n-1)
for d in reversed(droplist):
del words[d]
# The next step is to decode each encoded word by applying the reverse
# base64 or quopri transformation. decoded_words is now a list of the
# form (decoded_word, charset).
decoded_words = []
for encoded_string, encoding, charset in words:
if encoding is None:
# This is an unencoded word.
decoded_words.append((encoded_string, charset))
elif encoding == 'q':
word = header_decode(encoded_string)
decoded_words.append((word, charset))
elif encoding == 'b':
paderr = len(encoded_string) % 4 # Postel's law: add missing padding
if paderr:
encoded_string += '==='[:4 - paderr]
try:
word = base64mime.decode(encoded_string)
except binascii.Error:
raise HeaderParseError('Base64 decoding error')
else:
decoded_words.append((word, charset))
else:
raise AssertionError('Unexpected encoding: ' + encoding)
# Now convert all words to bytes and collapse consecutive runs of
# similarly encoded words.
collapsed = []
last_word = last_charset = None
for word, charset in decoded_words:
if isinstance(word, str):
word = bytes(word, 'raw-unicode-escape')
if last_word is None:
last_word = word
last_charset = charset
elif charset != last_charset:
collapsed.append((last_word, last_charset))
last_word = word
last_charset = charset
elif last_charset is None:
last_word += BSPACE + word
else:
last_word += word
collapsed.append((last_word, last_charset))
return collapsed
def make_header(decoded_seq, maxlinelen=None, header_name=None,
continuation_ws=' '):
"""Create a Header from a sequence of pairs as returned by decode_header()
decode_header() takes a header value string and returns a sequence of
pairs of the format (decoded_string, charset) where charset is the string
name of the character set.
This function takes one of those sequence of pairs and returns a Header
instance. Optional maxlinelen, header_name, and continuation_ws are as in
the Header constructor.
"""
h = Header(maxlinelen=maxlinelen, header_name=header_name,
continuation_ws=continuation_ws)
for s, charset in decoded_seq:
# None means us-ascii but we can simply pass it on to h.append()
if charset is not None and not isinstance(charset, Charset):
charset = Charset(charset)
h.append(s, charset)
return h
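# Editor's note (illustration only): decode_header() splits RFC 2047
# encoded words into (bytes, charset) pairs, and make_header() turns such a
# sequence back into a Header object. For example:
#
#     decode_header('=?utf-8?q?Caf=C3=A9?=')
#         -> [(b'Caf\xc3\xa9', 'utf-8')]
#     decode_header('plain text')
#         -> [('plain text', None)]
#     str(make_header(decode_header('=?utf-8?q?Caf=C3=A9?=')))
#         -> 'Café'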
class Header(object):
def __init__(self, s=None, charset=None,
maxlinelen=None, header_name=None,
continuation_ws=' ', errors='strict'):
"""Create a MIME-compliant header that can contain many character sets.
Optional s is the initial header value. If None, the initial header
value is not set. You can later append to the header with .append()
method calls. s may be a byte string or a Unicode string, but see the
.append() documentation for semantics.
Optional charset serves two purposes: it has the same meaning as the
charset argument to the .append() method. It also sets the default
character set for all subsequent .append() calls that omit the charset
argument. If charset is not provided in the constructor, the us-ascii
charset is used both as s's initial charset and as the default for
subsequent .append() calls.
The maximum line length can be specified explicitly via maxlinelen. For
splitting the first line to a shorter value (to account for the field
header which isn't included in s, e.g. `Subject') pass in the name of
the field in header_name. The default maxlinelen is 78 as recommended
by RFC 2822.
continuation_ws must be RFC 2822 compliant folding whitespace (usually
either a space or a hard tab) which will be prepended to continuation
lines.
errors is passed through to the .append() call.
"""
if charset is None:
charset = USASCII
elif not isinstance(charset, Charset):
charset = Charset(charset)
self._charset = charset
self._continuation_ws = continuation_ws
self._chunks = []
if s is not None:
self.append(s, charset, errors)
if maxlinelen is None:
maxlinelen = MAXLINELEN
self._maxlinelen = maxlinelen
if header_name is None:
self._headerlen = 0
else:
# Take the separating colon and space into account.
self._headerlen = len(header_name) + 2
def __str__(self):
"""Return the string value of the header."""
self._normalize()
uchunks = []
lastcs = None
lastspace = None
for string, charset in self._chunks:
# We must preserve spaces between encoded and non-encoded word
# boundaries, which means for us we need to add a space when we go
# from a charset to None/us-ascii, or from None/us-ascii to a
# charset. Only do this for the second and subsequent chunks.
# Don't add a space if the None/us-ascii string already has
# a space (trailing or leading depending on transition)
nextcs = charset
if nextcs == _charset.UNKNOWN8BIT:
original_bytes = string.encode('ascii', 'surrogateescape')
string = original_bytes.decode('ascii', 'replace')
if uchunks:
hasspace = string and self._nonctext(string[0])
if lastcs not in (None, 'us-ascii'):
if nextcs in (None, 'us-ascii') and not hasspace:
uchunks.append(SPACE)
nextcs = None
elif nextcs not in (None, 'us-ascii') and not lastspace:
uchunks.append(SPACE)
lastspace = string and self._nonctext(string[-1])
lastcs = nextcs
uchunks.append(string)
return EMPTYSTRING.join(uchunks)
# Rich comparison operators for equality only. BAW: does it make sense to
# have or explicitly disable <, <=, >, >= operators?
def __eq__(self, other):
# other may be a Header or a string. Both are fine so coerce
# ourselves to a unicode (of the unencoded header value), swap the
# args and do another comparison.
return other == str(self)
def __ne__(self, other):
return not self == other
def append(self, s, charset=None, errors='strict'):
"""Append a string to the MIME header.
Optional charset, if given, should be a Charset instance or the name
of a character set (which will be converted to a Charset instance). A
value of None (the default) means that the charset given in the
constructor is used.
s may be a byte string or a Unicode string. If it is a byte string
(i.e. isinstance(s, str) is false), then charset is the encoding of
that byte string, and a UnicodeError will be raised if the string
cannot be decoded with that charset. If s is a Unicode string, then
charset is a hint specifying the character set of the characters in
the string. In either case, when producing an RFC 2822 compliant
header using RFC 2047 rules, the string will be encoded using the
output codec of the charset. If the string cannot be encoded to the
output codec, a UnicodeError will be raised.
Optional `errors' is passed as the errors argument to the decode
call if s is a byte string.
"""
if charset is None:
charset = self._charset
elif not isinstance(charset, Charset):
charset = Charset(charset)
if not isinstance(s, str):
input_charset = charset.input_codec or 'us-ascii'
if input_charset == _charset.UNKNOWN8BIT:
s = s.decode('us-ascii', 'surrogateescape')
else:
s = s.decode(input_charset, errors)
# Ensure that the bytes we're storing can be decoded to the output
# character set, otherwise an early error is raised.
output_charset = charset.output_codec or 'us-ascii'
if output_charset != _charset.UNKNOWN8BIT:
try:
s.encode(output_charset, errors)
except UnicodeEncodeError:
if output_charset!='us-ascii':
raise
charset = UTF8
self._chunks.append((s, charset))
def _nonctext(self, s):
"""True if string s is not a ctext character of RFC822.
"""
return s.isspace() or s in ('(', ')', '\\')
def encode(self, splitchars=';, \t', maxlinelen=None, linesep='\n'):
r"""Encode a message header into an RFC-compliant format.
There are many issues involved in converting a given string for use in
an email header. Only certain character sets are readable in most
email clients, and as header strings can only contain a subset of
7-bit ASCII, care must be taken to properly convert and encode (with
Base64 or quoted-printable) header strings. In addition, there is a
75-character length limit on any given encoded header field, so
line-wrapping must be performed, even with double-byte character sets.
Optional maxlinelen specifies the maximum length of each generated
line, exclusive of the linesep string. Individual lines may be longer
than maxlinelen if a folding point cannot be found. The first line
will be shorter by the length of the header name plus ": " if a header
name was specified at Header construction time. The default value for
maxlinelen is determined at header construction time.
Optional splitchars is a string containing characters which should be
given extra weight by the splitting algorithm during normal header
wrapping. This is in very rough support of RFC 2822's `higher level
syntactic breaks': split points preceded by a splitchar are preferred
during line splitting, with the characters preferred in the order in
which they appear in the string. Space and tab may be included in the
string to indicate whether preference should be given to one over the
other as a split point when other split chars do not appear in the line
being split. Splitchars does not affect RFC 2047 encoded lines.
Optional linesep is a string to be used to separate the lines of
the value. The default value is the most useful for typical
Python applications, but it can be set to \r\n to produce RFC-compliant
line separators when needed.
"""
self._normalize()
if maxlinelen is None:
maxlinelen = self._maxlinelen
# A maxlinelen of 0 means don't wrap. For all practical purposes,
# choosing a huge number here accomplishes that and makes the
# _ValueFormatter algorithm much simpler.
if maxlinelen == 0:
maxlinelen = 1000000
formatter = _ValueFormatter(self._headerlen, maxlinelen,
self._continuation_ws, splitchars)
lastcs = None
hasspace = lastspace = None
for string, charset in self._chunks:
if hasspace is not None:
hasspace = string and self._nonctext(string[0])
import sys
if lastcs not in (None, 'us-ascii'):
if not hasspace or charset not in (None, 'us-ascii'):
formatter.add_transition()
elif charset not in (None, 'us-ascii') and not lastspace:
formatter.add_transition()
lastspace = string and self._nonctext(string[-1])
lastcs = charset
hasspace = False
lines = string.splitlines()
if lines:
formatter.feed('', lines[0], charset)
else:
formatter.feed('', '', charset)
for line in lines[1:]:
formatter.newline()
if charset.header_encoding is not None:
formatter.feed(self._continuation_ws, ' ' + line.lstrip(),
charset)
else:
sline = line.lstrip()
fws = line[:len(line)-len(sline)]
formatter.feed(fws, sline, charset)
if len(lines) > 1:
formatter.newline()
if self._chunks:
formatter.add_transition()
value = formatter._str(linesep)
if _embeded_header.search(value):
raise HeaderParseError("header value appears to contain "
"an embedded header: {!r}".format(value))
return value
def _normalize(self):
# Step 1: Normalize the chunks so that all runs of identical charsets
# get collapsed into a single unicode string.
chunks = []
last_charset = None
last_chunk = []
for string, charset in self._chunks:
if charset == last_charset:
last_chunk.append(string)
else:
if last_charset is not None:
chunks.append((SPACE.join(last_chunk), last_charset))
last_chunk = [string]
last_charset = charset
if last_chunk:
chunks.append((SPACE.join(last_chunk), last_charset))
self._chunks = chunks
class _ValueFormatter(object):
def __init__(self, headerlen, maxlen, continuation_ws, splitchars):
self._maxlen = maxlen
self._continuation_ws = continuation_ws
self._continuation_ws_len = len(continuation_ws)
self._splitchars = splitchars
self._lines = []
self._current_line = _Accumulator(headerlen)
def _str(self, linesep):
self.newline()
return linesep.join(self._lines)
def __str__(self):
return self._str(NL)
def newline(self):
end_of_line = self._current_line.pop()
if end_of_line != (' ', ''):
self._current_line.push(*end_of_line)
if len(self._current_line) > 0:
if self._current_line.is_onlyws():
self._lines[-1] += str(self._current_line)
else:
self._lines.append(str(self._current_line))
self._current_line.reset()
def add_transition(self):
self._current_line.push(' ', '')
def feed(self, fws, string, charset):
# If the charset has no header encoding (i.e. it is an ASCII encoding)
# then we must split the header at the "highest level syntactic break"
# possible. Note that we don't have a lot of smarts about field
# syntax; we just try to break on semi-colons, then commas, then
# whitespace. Eventually, this should be pluggable.
if charset.header_encoding is None:
self._ascii_split(fws, string, self._splitchars)
return
# Otherwise, we're doing either a Base64 or a quoted-printable
# encoding which means we don't need to split the line on syntactic
# breaks. We can basically just find enough characters to fit on the
# current line, minus the RFC 2047 chrome. What makes this trickier
# though is that we have to split at octet boundaries, not character
# boundaries but it's only safe to split at character boundaries so at
# best we can only get close.
encoded_lines = charset.header_encode_lines(string, self._maxlengths())
# The first element extends the current line, but if it's None then
# nothing more fit on the current line so start a new line.
try:
first_line = encoded_lines.pop(0)
except IndexError:
# There are no encoded lines, so we're done.
return
if first_line is not None:
self._append_chunk(fws, first_line)
try:
last_line = encoded_lines.pop()
except IndexError:
# There was only one line.
return
self.newline()
self._current_line.push(self._continuation_ws, last_line)
# Everything else are full lines in themselves.
for line in encoded_lines:
self._lines.append(self._continuation_ws + line)
def _maxlengths(self):
# The first line's length.
yield self._maxlen - len(self._current_line)
while True:
yield self._maxlen - self._continuation_ws_len
def _ascii_split(self, fws, string, splitchars):
# The RFC 2822 header folding algorithm is simple in principle but
# complex in practice. Lines may be folded any place where "folding
# white space" appears by inserting a linesep character in front of the
# FWS. The complication is that not all spaces or tabs qualify as FWS,
# and we are also supposed to prefer to break at "higher level
# syntactic breaks". We can't do either of these without intimate
# knowledge of the structure of structured headers, which we don't have
# here. So the best we can do here is prefer to break at the specified
# splitchars, and hope that we don't choose any spaces or tabs that
# aren't legal FWS. (This is at least better than the old algorithm,
# where we would sometimes *introduce* FWS after a splitchar, or the
# algorithm before that, where we would turn all white space runs into
# single spaces or tabs.)
parts = re.split("(["+FWS+"]+)", fws+string)
if parts[0]:
parts[:0] = ['']
else:
parts.pop(0)
for fws, part in zip(*[iter(parts)]*2):
self._append_chunk(fws, part)
def _append_chunk(self, fws, string):
self._current_line.push(fws, string)
if len(self._current_line) > self._maxlen:
# Find the best split point, working backward from the end.
# There might be none, on a long first line.
for ch in self._splitchars:
for i in range(self._current_line.part_count()-1, 0, -1):
if ch.isspace():
fws = self._current_line[i][0]
if fws and fws[0]==ch:
break
prevpart = self._current_line[i-1][1]
if prevpart and prevpart[-1]==ch:
break
else:
continue
break
else:
fws, part = self._current_line.pop()
if self._current_line._initial_size > 0:
# There will be a header, so leave it on a line by itself.
self.newline()
if not fws:
# We don't use continuation_ws here because the whitespace
# after a header should always be a space.
fws = ' '
self._current_line.push(fws, part)
return
remainder = self._current_line.pop_from(i)
self._lines.append(str(self._current_line))
self._current_line.reset(remainder)
class _Accumulator(list):
def __init__(self, initial_size=0):
self._initial_size = initial_size
super().__init__()
def push(self, fws, string):
self.append((fws, string))
def pop_from(self, i=0):
popped = self[i:]
self[i:] = []
return popped
def pop(self):
if self.part_count()==0:
return ('', '')
return super().pop()
def __len__(self):
return sum((len(fws)+len(part) for fws, part in self),
self._initial_size)
def __str__(self):
return EMPTYSTRING.join((EMPTYSTRING.join((fws, part))
for fws, part in self))
def reset(self, startval=None):
if startval is None:
startval = []
self[:] = startval
self._initial_size = 0
def is_onlyws(self):
return self._initial_size==0 and (not self or str(self).isspace())
def part_count(self):
return super().__len__()
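# --- Usage sketch (editor's illustration, not part of this backport) ---
# A Header can mix charsets; encode() folds and RFC 2047-encodes as needed.
if __name__ == '__main__':
    h = Header('Hello', charset='us-ascii')
    h.append('Grüße', charset='utf-8')
    print(h.encode())   # e.g. 'Hello =?utf-8?b?R3LDvMOfZQ==?='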

View file

@@ -0,0 +1,592 @@
"""Representing and manipulating email headers via custom objects.
This module provides an implementation of the HeaderRegistry API.
The implementation is designed to flexibly follow RFC5322 rules.
Eventually HeaderRegistry will be a public API, but it isn't yet,
and will probably change some before that happens.
"""
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from future.builtins import super
from future.builtins import str
from future.utils import text_to_native_str
from future.backports.email import utils
from future.backports.email import errors
from future.backports.email import _header_value_parser as parser
class Address(object):
def __init__(self, display_name='', username='', domain='', addr_spec=None):
"""Create an object represeting a full email address.
An address can have a 'display_name', a 'username', and a 'domain'. In
addition to specifying the username and domain separately, they may be
specified together by using the addr_spec keyword *instead of* the
username and domain keywords. If an addr_spec string is specified it
must be properly quoted according to RFC 5322 rules; an error will be
raised if it is not.
An Address object has display_name, username, domain, and addr_spec
attributes, all of which are read-only. The addr_spec and the string
value of the object are both quoted according to RFC5322 rules, but
without any Content Transfer Encoding.
"""
# This clause with its potential 'raise' may only happen when an
# application program creates an Address object using an addr_spec
# keyword. The email library code itself must always supply username
# and domain.
if addr_spec is not None:
if username or domain:
raise TypeError("addrspec specified when username and/or "
"domain also specified")
a_s, rest = parser.get_addr_spec(addr_spec)
if rest:
raise ValueError("Invalid addr_spec; only '{}' "
"could be parsed from '{}'".format(
a_s, addr_spec))
if a_s.all_defects:
raise a_s.all_defects[0]
username = a_s.local_part
domain = a_s.domain
self._display_name = display_name
self._username = username
self._domain = domain
@property
def display_name(self):
return self._display_name
@property
def username(self):
return self._username
@property
def domain(self):
return self._domain
@property
def addr_spec(self):
"""The addr_spec (username@domain) portion of the address, quoted
according to RFC 5322 rules, but with no Content Transfer Encoding.
"""
nameset = set(self.username)
if len(nameset) > len(nameset-parser.DOT_ATOM_ENDS):
lp = parser.quote_string(self.username)
else:
lp = self.username
if self.domain:
return lp + '@' + self.domain
if not lp:
return '<>'
return lp
def __repr__(self):
return "Address(display_name={!r}, username={!r}, domain={!r})".format(
self.display_name, self.username, self.domain)
def __str__(self):
nameset = set(self.display_name)
if len(nameset) > len(nameset-parser.SPECIALS):
disp = parser.quote_string(self.display_name)
else:
disp = self.display_name
if disp:
addr_spec = '' if self.addr_spec=='<>' else self.addr_spec
return "{} <{}>".format(disp, addr_spec)
return self.addr_spec
def __eq__(self, other):
if type(other) != type(self):
return False
return (self.display_name == other.display_name and
self.username == other.username and
self.domain == other.domain)
class Group(object):
def __init__(self, display_name=None, addresses=None):
"""Create an object representing an address group.
An address group consists of a display_name followed by colon and an
list of addresses (see Address) terminated by a semi-colon. The Group
is created by specifying a display_name and a possibly empty list of
Address objects. A Group can also be used to represent a single
address that is not in a group, which is convenient when manipulating
lists that are a combination of Groups and individual Addresses. In
this case the display_name should be set to None. In particular, the
string representation of a Group whose display_name is None is the same
as the Address object, if there is one and only one Address object in
the addresses list.
"""
self._display_name = display_name
self._addresses = tuple(addresses) if addresses else tuple()
@property
def display_name(self):
return self._display_name
@property
def addresses(self):
return self._addresses
def __repr__(self):
return "Group(display_name={!r}, addresses={!r}".format(
self.display_name, self.addresses)
def __str__(self):
if self.display_name is None and len(self.addresses)==1:
return str(self.addresses[0])
disp = self.display_name
if disp is not None:
nameset = set(disp)
if len(nameset) > len(nameset-parser.SPECIALS):
disp = parser.quote_string(disp)
adrstr = ", ".join(str(x) for x in self.addresses)
adrstr = ' ' + adrstr if adrstr else adrstr
return "{}:{};".format(disp, adrstr)
def __eq__(self, other):
if type(other) != type(self):
return False
return (self.display_name == other.display_name and
self.addresses == other.addresses)
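# Editor's note (illustration only): how these two classes render.
#
#     a = Address(display_name='Ann Example', username='ann', domain='example.com')
#     str(a)          # -> 'Ann Example <ann@example.com>'
#     a.addr_spec     # -> 'ann@example.com'
#
#     g = Group('undisclosed-recipients')      # empty address list
#     str(g)          # -> 'undisclosed-recipients:;'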
# Header Classes #
class BaseHeader(str):
"""Base class for message headers.
Implements generic behavior and provides tools for subclasses.
A subclass must define a classmethod named 'parse' that takes an unfolded
value string and a dictionary as its arguments. The dictionary will
contain one key, 'defects', initialized to an empty list. After the call
the dictionary must contain two additional keys: parse_tree, set to the
parse tree obtained from parsing the header, and 'decoded', set to the
string value of the idealized representation of the data from the value.
(That is, encoded words are decoded, and values that have canonical
representations are so represented.)
The defects key is intended to collect parsing defects, which the message
parser will subsequently dispose of as appropriate. The parser should not,
insofar as practical, raise any errors. Defects should be added to the
list instead. The standard header parsers register defects for RFC
compliance issues, for obsolete RFC syntax, and for unrecoverable parsing
errors.
The parse method may add additional keys to the dictionary. In this case
the subclass must define an 'init' method, which will be passed the
dictionary as its keyword arguments. The method should use (usually by
setting them as the value of similarly named attributes) and remove all the
extra keys added by its parse method, and then use super to call its parent
class with the remaining arguments and keywords.
The subclass should also make sure that a 'max_count' attribute is defined
that is either None or 1. XXX: need to better define this API.
"""
def __new__(cls, name, value):
kwds = {'defects': []}
cls.parse(value, kwds)
if utils._has_surrogates(kwds['decoded']):
kwds['decoded'] = utils._sanitize(kwds['decoded'])
self = str.__new__(cls, kwds['decoded'])
# del kwds['decoded']
self.init(name, **kwds)
return self
def init(self, name, **_3to2kwargs):
defects = _3to2kwargs['defects']; del _3to2kwargs['defects']
parse_tree = _3to2kwargs['parse_tree']; del _3to2kwargs['parse_tree']
self._name = name
self._parse_tree = parse_tree
self._defects = defects
@property
def name(self):
return self._name
@property
def defects(self):
return tuple(self._defects)
def __reduce__(self):
return (
_reconstruct_header,
(
self.__class__.__name__,
self.__class__.__bases__,
str(self),
),
self.__dict__)
@classmethod
def _reconstruct(cls, value):
return str.__new__(cls, value)
def fold(self, **_3to2kwargs):
policy = _3to2kwargs['policy']; del _3to2kwargs['policy']
"""Fold header according to policy.
The parsed representation of the header is folded according to
RFC5322 rules, as modified by the policy. If the parse tree
contains surrogateescaped bytes, the bytes are CTE encoded using
the charset 'unknown-8bit'.
Any non-ASCII characters in the parse tree are CTE encoded using
charset utf-8. XXX: make this a policy setting.
The returned value is an ASCII-only string possibly containing linesep
characters, and ending with a linesep character. The string includes
the header name and the ': ' separator.
"""
# At some point we need to only put fws here if it was in the source.
header = parser.Header([
parser.HeaderLabel([
parser.ValueTerminal(self.name, 'header-name'),
parser.ValueTerminal(':', 'header-sep')]),
parser.CFWSList([parser.WhiteSpaceTerminal(' ', 'fws')]),
self._parse_tree])
return header.fold(policy=policy)
def _reconstruct_header(cls_name, bases, value):
return type(text_to_native_str(cls_name), bases, {})._reconstruct(value)
class UnstructuredHeader(object):
max_count = None
value_parser = staticmethod(parser.get_unstructured)
@classmethod
def parse(cls, value, kwds):
kwds['parse_tree'] = cls.value_parser(value)
kwds['decoded'] = str(kwds['parse_tree'])
class UniqueUnstructuredHeader(UnstructuredHeader):
max_count = 1
class DateHeader(object):
"""Header whose value consists of a single timestamp.
Provides an additional attribute, datetime, which is either an aware
datetime using a timezone, or a naive datetime if the timezone
in the input string is -0000. Also accepts a datetime as input.
The 'value' attribute is the normalized form of the timestamp,
which means it is the output of format_datetime on the datetime.
"""
max_count = None
# This is used only for folding, not for creating 'decoded'.
value_parser = staticmethod(parser.get_unstructured)
@classmethod
def parse(cls, value, kwds):
if not value:
kwds['defects'].append(errors.HeaderMissingRequiredValue())
kwds['datetime'] = None
kwds['decoded'] = ''
kwds['parse_tree'] = parser.TokenList()
return
if isinstance(value, str):
value = utils.parsedate_to_datetime(value)
kwds['datetime'] = value
kwds['decoded'] = utils.format_datetime(kwds['datetime'])
kwds['parse_tree'] = cls.value_parser(kwds['decoded'])
def init(self, *args, **kw):
self._datetime = kw.pop('datetime')
super().init(*args, **kw)
@property
def datetime(self):
return self._datetime
class UniqueDateHeader(DateHeader):
max_count = 1
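# Illustrative sketch (not part of this module): the normalization described
# in the DateHeader docstring, using the same utils helpers its parse()
# method calls above. The sample date string is an example only.
_dt = utils.parsedate_to_datetime('Wed, 10 Jun 2020 12:04:54 -0400')
utils.format_datetime(_dt)       # 'Wed, 10 Jun 2020 12:04:54 -0400'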
class AddressHeader(object):
max_count = None
@staticmethod
def value_parser(value):
address_list, value = parser.get_address_list(value)
assert not value, 'this should not happen'
return address_list
@classmethod
def parse(cls, value, kwds):
if isinstance(value, str):
# We are translating here from the RFC language (address/mailbox)
# to our API language (group/address).
kwds['parse_tree'] = address_list = cls.value_parser(value)
groups = []
for addr in address_list.addresses:
groups.append(Group(addr.display_name,
[Address(mb.display_name or '',
mb.local_part or '',
mb.domain or '')
for mb in addr.all_mailboxes]))
defects = list(address_list.all_defects)
else:
# Assume it is Address/Group stuff
if not hasattr(value, '__iter__'):
value = [value]
groups = [Group(None, [item]) if not hasattr(item, 'addresses')
else item
for item in value]
defects = []
kwds['groups'] = groups
kwds['defects'] = defects
kwds['decoded'] = ', '.join([str(item) for item in groups])
if 'parse_tree' not in kwds:
kwds['parse_tree'] = cls.value_parser(kwds['decoded'])
def init(self, *args, **kw):
self._groups = tuple(kw.pop('groups'))
self._addresses = None
super().init(*args, **kw)
@property
def groups(self):
return self._groups
@property
def addresses(self):
if self._addresses is None:
self._addresses = tuple([address for group in self._groups
for address in group.addresses])
return self._addresses
class UniqueAddressHeader(AddressHeader):
max_count = 1
class SingleAddressHeader(AddressHeader):
@property
def address(self):
if len(self.addresses)!=1:
raise ValueError(("value of single address header {} is not "
"a single address").format(self.name))
return self.addresses[0]
class UniqueSingleAddressHeader(SingleAddressHeader):
max_count = 1
class MIMEVersionHeader(object):
max_count = 1
value_parser = staticmethod(parser.parse_mime_version)
@classmethod
def parse(cls, value, kwds):
kwds['parse_tree'] = parse_tree = cls.value_parser(value)
kwds['decoded'] = str(parse_tree)
kwds['defects'].extend(parse_tree.all_defects)
kwds['major'] = None if parse_tree.minor is None else parse_tree.major
kwds['minor'] = parse_tree.minor
if parse_tree.minor is not None:
kwds['version'] = '{}.{}'.format(kwds['major'], kwds['minor'])
else:
kwds['version'] = None
def init(self, *args, **kw):
self._version = kw.pop('version')
self._major = kw.pop('major')
self._minor = kw.pop('minor')
super().init(*args, **kw)
@property
def major(self):
return self._major
@property
def minor(self):
return self._minor
@property
def version(self):
return self._version
class ParameterizedMIMEHeader(object):
# Mixin that handles the params dict. Must be subclassed and
# a property value_parser for the specific header provided.
max_count = 1
@classmethod
def parse(cls, value, kwds):
kwds['parse_tree'] = parse_tree = cls.value_parser(value)
kwds['decoded'] = str(parse_tree)
kwds['defects'].extend(parse_tree.all_defects)
if parse_tree.params is None:
kwds['params'] = {}
else:
# The MIME RFCs specify that parameter ordering is arbitrary.
kwds['params'] = dict((utils._sanitize(name).lower(),
utils._sanitize(value))
for name, value in parse_tree.params)
def init(self, *args, **kw):
self._params = kw.pop('params')
super().init(*args, **kw)
@property
def params(self):
return self._params.copy()
class ContentTypeHeader(ParameterizedMIMEHeader):
value_parser = staticmethod(parser.parse_content_type_header)
def init(self, *args, **kw):
super().init(*args, **kw)
self._maintype = utils._sanitize(self._parse_tree.maintype)
self._subtype = utils._sanitize(self._parse_tree.subtype)
@property
def maintype(self):
return self._maintype
@property
def subtype(self):
return self._subtype
@property
def content_type(self):
return self.maintype + '/' + self.subtype
class ContentDispositionHeader(ParameterizedMIMEHeader):
value_parser = staticmethod(parser.parse_content_disposition_header)
def init(self, *args, **kw):
super().init(*args, **kw)
cd = self._parse_tree.content_disposition
self._content_disposition = cd if cd is None else utils._sanitize(cd)
@property
def content_disposition(self):
return self._content_disposition
class ContentTransferEncodingHeader(object):
max_count = 1
value_parser = staticmethod(parser.parse_content_transfer_encoding_header)
@classmethod
def parse(cls, value, kwds):
kwds['parse_tree'] = parse_tree = cls.value_parser(value)
kwds['decoded'] = str(parse_tree)
kwds['defects'].extend(parse_tree.all_defects)
def init(self, *args, **kw):
super().init(*args, **kw)
self._cte = utils._sanitize(self._parse_tree.cte)
@property
def cte(self):
return self._cte
# The header factory #
_default_header_map = {
'subject': UniqueUnstructuredHeader,
'date': UniqueDateHeader,
'resent-date': DateHeader,
'orig-date': UniqueDateHeader,
'sender': UniqueSingleAddressHeader,
'resent-sender': SingleAddressHeader,
'to': UniqueAddressHeader,
'resent-to': AddressHeader,
'cc': UniqueAddressHeader,
'resent-cc': AddressHeader,
'bcc': UniqueAddressHeader,
'resent-bcc': AddressHeader,
'from': UniqueAddressHeader,
'resent-from': AddressHeader,
'reply-to': UniqueAddressHeader,
'mime-version': MIMEVersionHeader,
'content-type': ContentTypeHeader,
'content-disposition': ContentDispositionHeader,
'content-transfer-encoding': ContentTransferEncodingHeader,
}
class HeaderRegistry(object):
"""A header_factory and header registry."""
def __init__(self, base_class=BaseHeader, default_class=UnstructuredHeader,
use_default_map=True):
"""Create a header_factory that works with the Policy API.
base_class is the class that will be the last class in the created
header class's __bases__ list. default_class is the class that will be
used if "name" (see __call__) does not appear in the registry.
use_default_map controls whether or not the default mapping of names to
specialized classes is copied in to the registry when the factory is
created. The default is True.
"""
self.registry = {}
self.base_class = base_class
self.default_class = default_class
if use_default_map:
self.registry.update(_default_header_map)
def map_to_type(self, name, cls):
"""Register cls as the specialized class for handling "name" headers.
"""
self.registry[name.lower()] = cls
def __getitem__(self, name):
cls = self.registry.get(name.lower(), self.default_class)
return type(text_to_native_str('_'+cls.__name__), (cls, self.base_class), {})
def __call__(self, name, value):
"""Create a header instance for header 'name' from 'value'.
Creates a header instance by creating a specialized class for parsing
and representing the specified header by combining the factory
base_class with a specialized class from the registry or the
default_class, and passing the name and value to the constructed
class's constructor.
"""
return self[name](name, value)
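# Illustrative usage sketch (not part of this module): the registry combines
# BaseHeader with the specialized classes above, so parsed header values
# expose typed attributes. Header names and values below are examples only.
_factory = HeaderRegistry()
_to = _factory('To', 'Eric Idle <eric@example.com>, friends: tom@example.com;')
_to.addresses                    # tuple of Address objects from all groups
_ct = _factory('Content-Type', 'text/plain; charset="utf-8"')
_ct.maintype, _ct.subtype        # ('text', 'plain')
_ct.params                       # {'charset': 'utf-8'}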

View file

@@ -0,0 +1,74 @@
# Copyright (C) 2001-2006 Python Software Foundation
# Author: Barry Warsaw
# Contact: email-sig@python.org
"""Various types of useful iterators and generators."""
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
__all__ = [
'body_line_iterator',
'typed_subpart_iterator',
'walk',
# Do not include _structure() since it's part of the debugging API.
]
import sys
from io import StringIO
# This function will become a method of the Message class
def walk(self):
"""Walk over the message tree, yielding each subpart.
The walk is performed in depth-first order. This method is a
generator.
"""
yield self
if self.is_multipart():
for subpart in self.get_payload():
for subsubpart in subpart.walk():
yield subsubpart
# These two functions are imported into the Iterators.py interface module.
def body_line_iterator(msg, decode=False):
"""Iterate over the parts, returning string payloads line-by-line.
Optional decode (default False) is passed through to .get_payload().
"""
for subpart in msg.walk():
payload = subpart.get_payload(decode=decode)
if isinstance(payload, str):
for line in StringIO(payload):
yield line
def typed_subpart_iterator(msg, maintype='text', subtype=None):
"""Iterate over the subparts with a given MIME type.
Use `maintype' as the main MIME type to match against; this defaults to
"text". Optional `subtype' is the MIME subtype to match against; if
omitted, only the main type is matched.
"""
for subpart in msg.walk():
if subpart.get_content_maintype() == maintype:
if subtype is None or subpart.get_content_subtype() == subtype:
yield subpart
def _structure(msg, fp=None, level=0, include_default=False):
"""A handy debugging aid"""
if fp is None:
fp = sys.stdout
tab = ' ' * (level * 4)
print(tab + msg.get_content_type(), end='', file=fp)
if include_default:
print(' [%s]' % msg.get_default_type(), file=fp)
else:
print(file=fp)
if msg.is_multipart():
for subpart in msg.get_payload():
_structure(subpart, fp, level+1, include_default)
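# Illustrative sketch (not part of this module): walking a small multipart
# message built with the mime.* classes added later in this commit.
from future.backports.email.mime.multipart import MIMEMultipart
from future.backports.email.mime.text import MIMEText
_msg = MIMEMultipart()
_msg.attach(MIMEText('hello'))
_msg.attach(MIMEText('<p>hello</p>', 'html'))
[p.get_content_type() for p in _msg.walk()]
# ['multipart/mixed', 'text/plain', 'text/html']
[p.get_content_type() for p in typed_subpart_iterator(_msg, 'text', 'plain')]
# ['text/plain']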

View file

@@ -0,0 +1,882 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2001-2007 Python Software Foundation
# Author: Barry Warsaw
# Contact: email-sig@python.org
"""Basic message object for the email package object model."""
from __future__ import absolute_import, division, unicode_literals
from future.builtins import list, range, str, zip
__all__ = ['Message']
import re
import uu
import base64
import binascii
from io import BytesIO, StringIO
# Intrapackage imports
from future.utils import as_native_str
from future.backports.email import utils
from future.backports.email import errors
from future.backports.email._policybase import compat32
from future.backports.email import charset as _charset
from future.backports.email._encoded_words import decode_b
Charset = _charset.Charset
SEMISPACE = '; '
# Regular expression that matches `special' characters in parameters, the
# existence of which force quoting of the parameter value.
tspecials = re.compile(r'[ \(\)<>@,;:\\"/\[\]\?=]')
def _splitparam(param):
# Split header parameters. BAW: this may be too simple. It isn't
# strictly RFC 2045 (section 5.1) compliant, but it catches most headers
# found in the wild. We may eventually need a full fledged parser.
# RDM: we might have a Header here; for now just stringify it.
a, sep, b = str(param).partition(';')
if not sep:
return a.strip(), None
return a.strip(), b.strip()
def _formatparam(param, value=None, quote=True):
"""Convenience function to format and return a key=value pair.
This will quote the value if needed or if quote is true. If value is a
three tuple (charset, language, value), it will be encoded according
to RFC2231 rules. If it contains non-ascii characters it will likewise
be encoded according to RFC2231 rules, using the utf-8 charset and
a null language.
"""
if value is not None and len(value) > 0:
# A tuple is used for RFC 2231 encoded parameter values where items
# are (charset, language, value). charset is a string, not a Charset
# instance. RFC 2231 encoded values are never quoted, per RFC.
if isinstance(value, tuple):
# Encode as per RFC 2231
param += '*'
value = utils.encode_rfc2231(value[2], value[0], value[1])
return '%s=%s' % (param, value)
else:
try:
value.encode('ascii')
except UnicodeEncodeError:
param += '*'
value = utils.encode_rfc2231(value, 'utf-8', '')
return '%s=%s' % (param, value)
# BAW: Please check this. I think that if quote is set it should
# force quoting even if not necessary.
if quote or tspecials.search(value):
return '%s="%s"' % (param, utils.quote(value))
else:
return '%s=%s' % (param, value)
else:
return param
def _parseparam(s):
# RDM This might be a Header, so for now stringify it.
s = ';' + str(s)
plist = []
while s[:1] == ';':
s = s[1:]
end = s.find(';')
while end > 0 and (s.count('"', 0, end) - s.count('\\"', 0, end)) % 2:
end = s.find(';', end + 1)
if end < 0:
end = len(s)
f = s[:end]
if '=' in f:
i = f.index('=')
f = f[:i].strip().lower() + '=' + f[i+1:].strip()
plist.append(f.strip())
s = s[end:]
return plist
def _unquotevalue(value):
# This is different than utils.collapse_rfc2231_value() because it doesn't
# try to convert the value to a unicode. Message.get_param() and
# Message.get_params() are both currently defined to return the tuple in
# the face of RFC 2231 parameters.
if isinstance(value, tuple):
return value[0], value[1], utils.unquote(value[2])
else:
return utils.unquote(value)
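# Illustrative sketch (not part of this module): what the private parameter
# helpers above produce for a typical Content-Type value.
_splitparam('text/plain; charset="utf-8"')
# ('text/plain', 'charset="utf-8"')
_parseparam('text/plain; charset="utf-8"; format=flowed')
# ['text/plain', 'charset="utf-8"', 'format=flowed']
_formatparam('charset', 'utf-8')
# 'charset="utf-8"'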
class Message(object):
"""Basic message object.
A message object is defined as something that has a bunch of RFC 2822
headers and a payload. It may optionally have an envelope header
(a.k.a. Unix-From or From_ header). If the message is a container (i.e. a
multipart or a message/rfc822), then the payload is a list of Message
objects, otherwise it is a string.
Message objects implement part of the `mapping' interface, which assumes
there is exactly one occurrence of the header per message. Some headers
do in fact appear multiple times (e.g. Received) and for those headers,
you must use the explicit API to set or get all the headers. Not all of
the mapping methods are implemented.
"""
def __init__(self, policy=compat32):
self.policy = policy
self._headers = list()
self._unixfrom = None
self._payload = None
self._charset = None
# Defaults for multipart messages
self.preamble = self.epilogue = None
self.defects = []
# Default content type
self._default_type = 'text/plain'
@as_native_str(encoding='utf-8')
def __str__(self):
"""Return the entire formatted message as a string.
This includes the headers, body, and envelope header.
"""
return self.as_string()
def as_string(self, unixfrom=False, maxheaderlen=0):
"""Return the entire formatted message as a (unicode) string.
Optional `unixfrom' when True, means include the Unix From_ envelope
header.
This is a convenience method and may not generate the message exactly
as you intend. For more flexibility, use the flatten() method of a
Generator instance.
"""
from future.backports.email.generator import Generator
fp = StringIO()
g = Generator(fp, mangle_from_=False, maxheaderlen=maxheaderlen)
g.flatten(self, unixfrom=unixfrom)
return fp.getvalue()
def is_multipart(self):
"""Return True if the message consists of multiple parts."""
return isinstance(self._payload, list)
#
# Unix From_ line
#
def set_unixfrom(self, unixfrom):
self._unixfrom = unixfrom
def get_unixfrom(self):
return self._unixfrom
#
# Payload manipulation.
#
def attach(self, payload):
"""Add the given payload to the current payload.
The current payload will always be a list of objects after this method
is called. If you want to set the payload to a scalar object, use
set_payload() instead.
"""
if self._payload is None:
self._payload = [payload]
else:
self._payload.append(payload)
def get_payload(self, i=None, decode=False):
"""Return a reference to the payload.
The payload will either be a list object or a string. If you mutate
the list object, you modify the message's payload in place. Optional
i returns that index into the payload.
Optional decode is a flag indicating whether the payload should be
decoded or not, according to the Content-Transfer-Encoding header
(default is False).
When True and the message is not a multipart, the payload will be
decoded if this header's value is `quoted-printable' or `base64'. If
some other encoding is used, or the header is missing, or if the
payload has bogus data (i.e. bogus base64 or uuencoded data), the
payload is returned as-is.
If the message is a multipart and the decode flag is True, then None
is returned.
"""
# Here is the logic table for this code, based on the email5.0.0 code:
# i decode is_multipart result
# ------ ------ ------------ ------------------------------
# None True True None
# i True True None
# None False True _payload (a list)
# i False True _payload element i (a Message)
# i False False error (not a list)
# i True False error (not a list)
# None False False _payload
# None True False _payload decoded (bytes)
# Note that Barry planned to factor out the 'decode' case, but that
# isn't so easy now that we handle the 8 bit data, which needs to be
# converted in both the decode and non-decode path.
if self.is_multipart():
if decode:
return None
if i is None:
return self._payload
else:
return self._payload[i]
# For backward compatibility, Use isinstance and this error message
# instead of the more logical is_multipart test.
if i is not None and not isinstance(self._payload, list):
raise TypeError('Expected list, got %s' % type(self._payload))
payload = self._payload
# cte might be a Header, so for now stringify it.
cte = str(self.get('content-transfer-encoding', '')).lower()
# payload may be bytes here.
if isinstance(payload, str):
payload = str(payload) # for Python-Future, so surrogateescape works
if utils._has_surrogates(payload):
bpayload = payload.encode('ascii', 'surrogateescape')
if not decode:
try:
payload = bpayload.decode(self.get_param('charset', 'ascii'), 'replace')
except LookupError:
payload = bpayload.decode('ascii', 'replace')
elif decode:
try:
bpayload = payload.encode('ascii')
except UnicodeError:
# This won't happen for RFC compliant messages (messages
# containing only ASCII codepoints in the unicode input).
# If it does happen, turn the string into bytes in a way
# guaranteed not to fail.
bpayload = payload.encode('raw-unicode-escape')
if not decode:
return payload
if cte == 'quoted-printable':
return utils._qdecode(bpayload)
elif cte == 'base64':
# XXX: this is a bit of a hack; decode_b should probably be factored
# out somewhere, but I haven't figured out where yet.
value, defects = decode_b(b''.join(bpayload.splitlines()))
for defect in defects:
self.policy.handle_defect(self, defect)
return value
elif cte in ('x-uuencode', 'uuencode', 'uue', 'x-uue'):
in_file = BytesIO(bpayload)
out_file = BytesIO()
try:
uu.decode(in_file, out_file, quiet=True)
return out_file.getvalue()
except uu.Error:
# Some decoding problem
return bpayload
if isinstance(payload, str):
return bpayload
return payload
def set_payload(self, payload, charset=None):
"""Set the payload to the given value.
Optional charset sets the message's default character set. See
set_charset() for details.
"""
self._payload = payload
if charset is not None:
self.set_charset(charset)
def set_charset(self, charset):
"""Set the charset of the payload to a given character set.
charset can be a Charset instance, a string naming a character set, or
None. If it is a string it will be converted to a Charset instance.
If charset is None, the charset parameter will be removed from the
Content-Type field. Anything else will generate a TypeError.
The message will be assumed to be of type text/* encoded with
charset.input_charset. It will be converted to charset.output_charset
and encoded properly, if needed, when generating the plain text
representation of the message. MIME headers (MIME-Version,
Content-Type, Content-Transfer-Encoding) will be added as needed.
"""
if charset is None:
self.del_param('charset')
self._charset = None
return
if not isinstance(charset, Charset):
charset = Charset(charset)
self._charset = charset
if 'MIME-Version' not in self:
self.add_header('MIME-Version', '1.0')
if 'Content-Type' not in self:
self.add_header('Content-Type', 'text/plain',
charset=charset.get_output_charset())
else:
self.set_param('charset', charset.get_output_charset())
if charset != charset.get_output_charset():
self._payload = charset.body_encode(self._payload)
if 'Content-Transfer-Encoding' not in self:
cte = charset.get_body_encoding()
try:
cte(self)
except TypeError:
self._payload = charset.body_encode(self._payload)
self.add_header('Content-Transfer-Encoding', cte)
def get_charset(self):
"""Return the Charset instance associated with the message's payload.
"""
return self._charset
#
# MAPPING INTERFACE (partial)
#
def __len__(self):
"""Return the total number of headers, including duplicates."""
return len(self._headers)
def __getitem__(self, name):
"""Get a header value.
Return None if the header is missing instead of raising an exception.
Note that if the header appeared multiple times, exactly which
occurrence gets returned is undefined. Use get_all() to get all
the values matching a header field name.
"""
return self.get(name)
def __setitem__(self, name, val):
"""Set the value of a header.
Note: this does not overwrite an existing header with the same field
name. Use __delitem__() first to delete any existing headers.
"""
max_count = self.policy.header_max_count(name)
if max_count:
lname = name.lower()
found = 0
for k, v in self._headers:
if k.lower() == lname:
found += 1
if found >= max_count:
raise ValueError("There may be at most {} {} headers "
"in a message".format(max_count, name))
self._headers.append(self.policy.header_store_parse(name, val))
def __delitem__(self, name):
"""Delete all occurrences of a header, if present.
Does not raise an exception if the header is missing.
"""
name = name.lower()
newheaders = list()
for k, v in self._headers:
if k.lower() != name:
newheaders.append((k, v))
self._headers = newheaders
def __contains__(self, name):
return name.lower() in [k.lower() for k, v in self._headers]
def __iter__(self):
for field, value in self._headers:
yield field
def keys(self):
"""Return a list of all the message's header field names.
These will be sorted in the order they appeared in the original
message, or were added to the message, and may contain duplicates.
Any fields deleted and re-inserted are always appended to the header
list.
"""
return [k for k, v in self._headers]
def values(self):
"""Return a list of all the message's header values.
These will be sorted in the order they appeared in the original
message, or were added to the message, and may contain duplicates.
Any fields deleted and re-inserted are always appended to the header
list.
"""
return [self.policy.header_fetch_parse(k, v)
for k, v in self._headers]
def items(self):
"""Get all the message's header fields and values.
These will be sorted in the order they appeared in the original
message, or were added to the message, and may contain duplicates.
Any fields deleted and re-inserted are always appended to the header
list.
"""
return [(k, self.policy.header_fetch_parse(k, v))
for k, v in self._headers]
def get(self, name, failobj=None):
"""Get a header value.
Like __getitem__() but return failobj instead of None when the field
is missing.
"""
name = name.lower()
for k, v in self._headers:
if k.lower() == name:
return self.policy.header_fetch_parse(k, v)
return failobj
#
# "Internal" methods (public API, but only intended for use by a parser
# or generator, not normal application code.
#
def set_raw(self, name, value):
"""Store name and value in the model without modification.
This is an "internal" API, intended only for use by a parser.
"""
self._headers.append((name, value))
def raw_items(self):
"""Return the (name, value) header pairs without modification.
This is an "internal" API, intended only for use by a generator.
"""
return iter(self._headers.copy())
#
# Additional useful stuff
#
def get_all(self, name, failobj=None):
"""Return a list of all the values for the named field.
These will be sorted in the order they appeared in the original
message, and may contain duplicates. Any fields deleted and
re-inserted are always appended to the header list.
If no such fields exist, failobj is returned (defaults to None).
"""
values = []
name = name.lower()
for k, v in self._headers:
if k.lower() == name:
values.append(self.policy.header_fetch_parse(k, v))
if not values:
return failobj
return values
def add_header(self, _name, _value, **_params):
"""Extended header setting.
name is the header field to add. keyword arguments can be used to set
additional parameters for the header field, with underscores converted
to dashes. Normally the parameter will be added as key="value" unless
value is None, in which case only the key will be added. If a
parameter value contains non-ASCII characters it can be specified as a
three-tuple of (charset, language, value), in which case it will be
encoded according to RFC2231 rules. Otherwise it will be encoded using
the utf-8 charset and a language of ''.
Examples:
msg.add_header('content-disposition', 'attachment', filename='bud.gif')
msg.add_header('content-disposition', 'attachment',
filename=('utf-8', '', 'Fußballer.ppt'))
msg.add_header('content-disposition', 'attachment',
filename='Fußballer.ppt')
"""
parts = []
for k, v in _params.items():
if v is None:
parts.append(k.replace('_', '-'))
else:
parts.append(_formatparam(k.replace('_', '-'), v))
if _value is not None:
parts.insert(0, _value)
self[_name] = SEMISPACE.join(parts)
def replace_header(self, _name, _value):
"""Replace a header.
Replace the first matching header found in the message, retaining
header order and case. If no matching header was found, a KeyError is
raised.
"""
_name = _name.lower()
for i, (k, v) in zip(range(len(self._headers)), self._headers):
if k.lower() == _name:
self._headers[i] = self.policy.header_store_parse(k, _value)
break
else:
raise KeyError(_name)
#
# Use these three methods instead of the three above.
#
def get_content_type(self):
"""Return the message's content type.
The returned string is coerced to lower case of the form
`maintype/subtype'. If there was no Content-Type header in the
message, the default type as given by get_default_type() will be
returned. Since according to RFC 2045, messages always have a default
type this will always return a value.
RFC 2045 defines a message's default type to be text/plain unless it
appears inside a multipart/digest container, in which case it would be
message/rfc822.
"""
missing = object()
value = self.get('content-type', missing)
if value is missing:
# This should have no parameters
return self.get_default_type()
ctype = _splitparam(value)[0].lower()
# RFC 2045, section 5.2 says if it's invalid, use text/plain
if ctype.count('/') != 1:
return 'text/plain'
return ctype
def get_content_maintype(self):
"""Return the message's main content type.
This is the `maintype' part of the string returned by
get_content_type().
"""
ctype = self.get_content_type()
return ctype.split('/')[0]
def get_content_subtype(self):
"""Returns the message's sub-content type.
This is the `subtype' part of the string returned by
get_content_type().
"""
ctype = self.get_content_type()
return ctype.split('/')[1]
def get_default_type(self):
"""Return the `default' content type.
Most messages have a default content type of text/plain, except for
messages that are subparts of multipart/digest containers. Such
subparts have a default content type of message/rfc822.
"""
return self._default_type
def set_default_type(self, ctype):
"""Set the `default' content type.
ctype should be either "text/plain" or "message/rfc822", although this
is not enforced. The default content type is not stored in the
Content-Type header.
"""
self._default_type = ctype
def _get_params_preserve(self, failobj, header):
# Like get_params() but preserves the quoting of values. BAW:
# should this be part of the public interface?
missing = object()
value = self.get(header, missing)
if value is missing:
return failobj
params = []
for p in _parseparam(value):
try:
name, val = p.split('=', 1)
name = name.strip()
val = val.strip()
except ValueError:
# Must have been a bare attribute
name = p.strip()
val = ''
params.append((name, val))
params = utils.decode_params(params)
return params
def get_params(self, failobj=None, header='content-type', unquote=True):
"""Return the message's Content-Type parameters, as a list.
The elements of the returned list are 2-tuples of key/value pairs, as
split on the `=' sign. The left hand side of the `=' is the key,
while the right hand side is the value. If there is no `=' sign in
the parameter the value is the empty string. The value is as
described in the get_param() method.
Optional failobj is the object to return if there is no Content-Type
header. Optional header is the header to search instead of
Content-Type. If unquote is True, the value is unquoted.
"""
missing = object()
params = self._get_params_preserve(missing, header)
if params is missing:
return failobj
if unquote:
return [(k, _unquotevalue(v)) for k, v in params]
else:
return params
def get_param(self, param, failobj=None, header='content-type',
unquote=True):
"""Return the parameter value if found in the Content-Type header.
Optional failobj is the object to return if there is no Content-Type
header, or the Content-Type header has no such parameter. Optional
header is the header to search instead of Content-Type.
Parameter keys are always compared case insensitively. The return
value can either be a string, or a 3-tuple if the parameter was RFC
2231 encoded. When it's a 3-tuple, the elements of the value are of
the form (CHARSET, LANGUAGE, VALUE). Note that both CHARSET and
LANGUAGE can be None, in which case you should consider VALUE to be
encoded in the us-ascii charset. You can usually ignore LANGUAGE.
The parameter value (either the returned string, or the VALUE item in
the 3-tuple) is always unquoted, unless unquote is set to False.
If your application doesn't care whether the parameter was RFC 2231
encoded, it can turn the return value into a string as follows:
rawparam = msg.get_param('foo')
param = email.utils.collapse_rfc2231_value(rawparam)
"""
if header not in self:
return failobj
for k, v in self._get_params_preserve(failobj, header):
if k.lower() == param.lower():
if unquote:
return _unquotevalue(v)
else:
return v
return failobj
def set_param(self, param, value, header='Content-Type', requote=True,
charset=None, language=''):
"""Set a parameter in the Content-Type header.
If the parameter already exists in the header, its value will be
replaced with the new value.
If header is Content-Type and has not yet been defined for this
message, it will be set to "text/plain" and the new parameter and
value will be appended as per RFC 2045.
An alternate header can be specified in the header argument, and all
parameters will be quoted as necessary unless requote is False.
If charset is specified, the parameter will be encoded according to RFC
2231. Optional language specifies the RFC 2231 language, defaulting
to the empty string. Both charset and language should be strings.
"""
if not isinstance(value, tuple) and charset:
value = (charset, language, value)
if header not in self and header.lower() == 'content-type':
ctype = 'text/plain'
else:
ctype = self.get(header)
if not self.get_param(param, header=header):
if not ctype:
ctype = _formatparam(param, value, requote)
else:
ctype = SEMISPACE.join(
[ctype, _formatparam(param, value, requote)])
else:
ctype = ''
for old_param, old_value in self.get_params(header=header,
unquote=requote):
append_param = ''
if old_param.lower() == param.lower():
append_param = _formatparam(param, value, requote)
else:
append_param = _formatparam(old_param, old_value, requote)
if not ctype:
ctype = append_param
else:
ctype = SEMISPACE.join([ctype, append_param])
if ctype != self.get(header):
del self[header]
self[header] = ctype
def del_param(self, param, header='content-type', requote=True):
"""Remove the given parameter completely from the Content-Type header.
The header will be re-written in place without the parameter or its
value. All values will be quoted as necessary unless requote is
False. Optional header specifies an alternative to the Content-Type
header.
"""
if header not in self:
return
new_ctype = ''
for p, v in self.get_params(header=header, unquote=requote):
if p.lower() != param.lower():
if not new_ctype:
new_ctype = _formatparam(p, v, requote)
else:
new_ctype = SEMISPACE.join([new_ctype,
_formatparam(p, v, requote)])
if new_ctype != self.get(header):
del self[header]
self[header] = new_ctype
def set_type(self, type, header='Content-Type', requote=True):
"""Set the main type and subtype for the Content-Type header.
type must be a string in the form "maintype/subtype", otherwise a
ValueError is raised.
This method replaces the Content-Type header, keeping all the
parameters in place. If requote is False, this leaves the existing
header's quoting as is. Otherwise, the parameters will be quoted (the
default).
An alternative header can be specified in the header argument. When
the Content-Type header is set, we'll always also add a MIME-Version
header.
"""
# BAW: should we be strict?
if not type.count('/') == 1:
raise ValueError
# Set the Content-Type, you get a MIME-Version
if header.lower() == 'content-type':
del self['mime-version']
self['MIME-Version'] = '1.0'
if header not in self:
self[header] = type
return
params = self.get_params(header=header, unquote=requote)
del self[header]
self[header] = type
# Skip the first param; it's the old type.
for p, v in params[1:]:
self.set_param(p, v, header, requote)
def get_filename(self, failobj=None):
"""Return the filename associated with the payload if present.
The filename is extracted from the Content-Disposition header's
`filename' parameter, and it is unquoted. If that header is missing
the `filename' parameter, this method falls back to looking for the
`name' parameter.
"""
missing = object()
filename = self.get_param('filename', missing, 'content-disposition')
if filename is missing:
filename = self.get_param('name', missing, 'content-type')
if filename is missing:
return failobj
return utils.collapse_rfc2231_value(filename).strip()
def get_boundary(self, failobj=None):
"""Return the boundary associated with the payload if present.
The boundary is extracted from the Content-Type header's `boundary'
parameter, and it is unquoted.
"""
missing = object()
boundary = self.get_param('boundary', missing)
if boundary is missing:
return failobj
# RFC 2046 says that boundaries may begin but not end in w/s
return utils.collapse_rfc2231_value(boundary).rstrip()
def set_boundary(self, boundary):
"""Set the boundary parameter in Content-Type to 'boundary'.
This is subtly different than deleting the Content-Type header and
adding a new one with a new boundary parameter via add_header(). The
main difference is that using the set_boundary() method preserves the
order of the Content-Type header in the original message.
HeaderParseError is raised if the message has no Content-Type header.
"""
missing = object()
params = self._get_params_preserve(missing, 'content-type')
if params is missing:
# There was no Content-Type header, and we don't know what type
# to set it to, so raise an exception.
raise errors.HeaderParseError('No Content-Type header found')
newparams = []
foundp = False
for pk, pv in params:
if pk.lower() == 'boundary':
newparams.append(('boundary', '"%s"' % boundary))
foundp = True
else:
newparams.append((pk, pv))
if not foundp:
# The original Content-Type header had no boundary attribute.
# Tack one on the end. BAW: should we raise an exception
# instead???
newparams.append(('boundary', '"%s"' % boundary))
# Replace the existing Content-Type header with the new value
newheaders = []
for h, v in self._headers:
if h.lower() == 'content-type':
parts = []
for k, v in newparams:
if v == '':
parts.append(k)
else:
parts.append('%s=%s' % (k, v))
val = SEMISPACE.join(parts)
newheaders.append(self.policy.header_store_parse(h, val))
else:
newheaders.append((h, v))
self._headers = newheaders
def get_content_charset(self, failobj=None):
"""Return the charset parameter of the Content-Type header.
The returned string is always coerced to lower case. If there is no
Content-Type header, or if that header has no charset parameter,
failobj is returned.
"""
missing = object()
charset = self.get_param('charset', missing)
if charset is missing:
return failobj
if isinstance(charset, tuple):
# RFC 2231 encoded, so decode it, and it better end up as ascii.
pcharset = charset[0] or 'us-ascii'
try:
# LookupError will be raised if the charset isn't known to
# Python. UnicodeError will be raised if the encoded text
# contains a character not in the charset.
as_bytes = charset[2].encode('raw-unicode-escape')
charset = str(as_bytes, pcharset)
except (LookupError, UnicodeError):
charset = charset[2]
# charset characters must be in us-ascii range
try:
charset.encode('us-ascii')
except UnicodeError:
return failobj
# RFC 2046, $4.1.2 says charsets are not case sensitive
return charset.lower()
def get_charsets(self, failobj=None):
"""Return a list containing the charset(s) used in this message.
The returned list of items describes the Content-Type headers'
charset parameter for this message and all the subparts in its
payload.
Each item will either be a string (the value of the charset parameter
in the Content-Type header of that part) or the value of the
'failobj' parameter (defaults to None), if the part does not have a
main MIME type of "text", or the charset is not defined.
The list will contain one string for each part of the message, plus
one for the container message (i.e. self), so that a non-multipart
message will still return a list of length 1.
"""
return [part.get_content_charset(failobj) for part in self.walk()]
# I.e. def walk(self): ...
from future.backports.email.iterators import walk
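# Illustrative sketch (not part of this module): the partial mapping
# interface described in the Message docstring, on a freshly built message.
_m = Message()
_m['Subject'] = 'Hello'
_m['Received'] = 'first hop'
_m['Received'] = 'second hop'
_m['Subject']                    # 'Hello'
_m.get_all('Received')           # ['first hop', 'second hop']
_m.set_payload('body text')
_m.get_payload()                 # 'body text'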

View file

@@ -0,0 +1,39 @@
# Copyright (C) 2001-2006 Python Software Foundation
# Author: Keith Dart
# Contact: email-sig@python.org
"""Class representing application/* type MIME documents."""
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from future.backports.email import encoders
from future.backports.email.mime.nonmultipart import MIMENonMultipart
__all__ = ["MIMEApplication"]
class MIMEApplication(MIMENonMultipart):
"""Class for generating application/* MIME documents."""
def __init__(self, _data, _subtype='octet-stream',
_encoder=encoders.encode_base64, **_params):
"""Create an application/* type MIME document.
_data is a string containing the raw application data.
_subtype is the MIME content type subtype, defaulting to
'octet-stream'.
_encoder is a function which will perform the actual encoding for
transport of the application data, defaulting to base64 encoding.
Any additional keyword arguments are passed to the base class
constructor, which turns them into parameters on the Content-Type
header.
"""
if _subtype is None:
raise TypeError('Invalid application MIME subtype')
MIMENonMultipart.__init__(self, 'application', _subtype, **_params)
self.set_payload(_data)
_encoder(self)

View file

@@ -0,0 +1,74 @@
# Copyright (C) 2001-2007 Python Software Foundation
# Author: Anthony Baxter
# Contact: email-sig@python.org
"""Class representing audio/* type MIME documents."""
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
__all__ = ['MIMEAudio']
import sndhdr
from io import BytesIO
from future.backports.email import encoders
from future.backports.email.mime.nonmultipart import MIMENonMultipart
_sndhdr_MIMEmap = {'au' : 'basic',
'wav' :'x-wav',
'aiff':'x-aiff',
'aifc':'x-aiff',
}
# There are others in sndhdr that don't have MIME types. :(
# Additional ones to be added to sndhdr? midi, mp3, realaudio, wma??
def _whatsnd(data):
"""Try to identify a sound file type.
sndhdr.what() has a pretty cruddy interface, unfortunately. This is why
we re-do it here. It would be easier to reverse engineer the Unix 'file'
command and use the standard 'magic' file, as shipped with a modern Unix.
"""
hdr = data[:512]
fakefile = BytesIO(hdr)
for testfn in sndhdr.tests:
res = testfn(hdr, fakefile)
if res is not None:
return _sndhdr_MIMEmap.get(res[0])
return None
class MIMEAudio(MIMENonMultipart):
"""Class for generating audio/* MIME documents."""
def __init__(self, _audiodata, _subtype=None,
_encoder=encoders.encode_base64, **_params):
"""Create an audio/* type MIME document.
_audiodata is a string containing the raw audio data. If this data
can be decoded by the standard Python `sndhdr' module, then the
subtype will be automatically included in the Content-Type header.
Otherwise, you can specify the specific audio subtype via the
_subtype parameter. If _subtype is not given, and no subtype can be
guessed, a TypeError is raised.
_encoder is a function which will perform the actual encoding for
transport of the image data. It takes one argument, which is this
Image instance. It should use get_payload() and set_payload() to
change the payload to the encoded form. It should also add any
Content-Transfer-Encoding or other headers to the message as
necessary. The default encoding is Base64.
Any additional keyword arguments are passed to the base class
constructor, which turns them into parameters on the Content-Type
header.
"""
if _subtype is None:
_subtype = _whatsnd(_audiodata)
if _subtype is None:
raise TypeError('Could not find audio MIME subtype')
MIMENonMultipart.__init__(self, 'audio', _subtype, **_params)
self.set_payload(_audiodata)
_encoder(self)

View file

@@ -0,0 +1,25 @@
# Copyright (C) 2001-2006 Python Software Foundation
# Author: Barry Warsaw
# Contact: email-sig@python.org
"""Base class for MIME specializations."""
from __future__ import absolute_import, division, unicode_literals
from future.backports.email import message
__all__ = ['MIMEBase']
class MIMEBase(message.Message):
"""Base class for MIME specializations."""
def __init__(self, _maintype, _subtype, **_params):
"""This constructor adds a Content-Type: and a MIME-Version: header.
The Content-Type: header is taken from the _maintype and _subtype
arguments. Additional parameters for this header are taken from the
keyword arguments.
"""
message.Message.__init__(self)
ctype = '%s/%s' % (_maintype, _subtype)
self.add_header('Content-Type', ctype, **_params)
self['MIME-Version'] = '1.0'

View file

@@ -0,0 +1,48 @@
# Copyright (C) 2001-2006 Python Software Foundation
# Author: Barry Warsaw
# Contact: email-sig@python.org
"""Class representing image/* type MIME documents."""
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
__all__ = ['MIMEImage']
import imghdr
from future.backports.email import encoders
from future.backports.email.mime.nonmultipart import MIMENonMultipart
class MIMEImage(MIMENonMultipart):
"""Class for generating image/* type MIME documents."""
def __init__(self, _imagedata, _subtype=None,
_encoder=encoders.encode_base64, **_params):
"""Create an image/* type MIME document.
_imagedata is a string containing the raw image data. If this data
can be decoded by the standard Python `imghdr' module, then the
subtype will be automatically included in the Content-Type header.
Otherwise, you can specify the specific image subtype via the _subtype
parameter.
_encoder is a function which will perform the actual encoding for
transport of the image data. It takes one argument, which is this
Image instance. It should use get_payload() and set_payload() to
change the payload to the encoded form. It should also add any
Content-Transfer-Encoding or other headers to the message as
necessary. The default encoding is Base64.
Any additional keyword arguments are passed to the base class
constructor, which turns them into parameters on the Content-Type
header.
"""
if _subtype is None:
_subtype = imghdr.what(None, _imagedata)
if _subtype is None:
raise TypeError('Could not guess image MIME subtype')
MIMENonMultipart.__init__(self, 'image', _subtype, **_params)
self.set_payload(_imagedata)
_encoder(self)

View file

@@ -0,0 +1,36 @@
# Copyright (C) 2001-2006 Python Software Foundation
# Author: Barry Warsaw
# Contact: email-sig@python.org
"""Class representing message/* MIME documents."""
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
__all__ = ['MIMEMessage']
from future.backports.email import message
from future.backports.email.mime.nonmultipart import MIMENonMultipart
class MIMEMessage(MIMENonMultipart):
"""Class representing message/* MIME documents."""
def __init__(self, _msg, _subtype='rfc822'):
"""Create a message/* type MIME document.
_msg is a message object and must be an instance of Message, or a
derived class of Message, otherwise a TypeError is raised.
Optional _subtype defines the subtype of the contained message. The
default is "rfc822" (this is defined by the MIME standard, even though
the term "rfc822" is technically outdated by RFC 2822).
"""
MIMENonMultipart.__init__(self, 'message', _subtype)
if not isinstance(_msg, message.Message):
raise TypeError('Argument is not an instance of Message')
# It's convenient to use this base class method. We need to do it
# this way or we'll get an exception
message.Message.attach(self, _msg)
# And be sure our default type is set correctly
self.set_default_type('message/rfc822')

View file

@@ -0,0 +1,49 @@
# Copyright (C) 2002-2006 Python Software Foundation
# Author: Barry Warsaw
# Contact: email-sig@python.org
"""Base class for MIME multipart/* type messages."""
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
__all__ = ['MIMEMultipart']
from future.backports.email.mime.base import MIMEBase
class MIMEMultipart(MIMEBase):
"""Base class for MIME multipart/* type messages."""
def __init__(self, _subtype='mixed', boundary=None, _subparts=None,
**_params):
"""Creates a multipart/* type message.
By default, creates a multipart/mixed message, with proper
Content-Type and MIME-Version headers.
_subtype is the subtype of the multipart content type, defaulting to
`mixed'.
boundary is the multipart boundary string. By default it is
calculated as needed.
_subparts is a sequence of initial subparts for the payload. It
must be an iterable object, such as a list. You can always
attach new subparts to the message by using the attach() method.
Additional parameters for the Content-Type header are taken from the
keyword arguments (or passed into the _params argument).
"""
MIMEBase.__init__(self, 'multipart', _subtype, **_params)
# Initialise _payload to an empty list as the Message superclass's
# implementation of is_multipart assumes that _payload is a list for
# multipart messages.
self._payload = []
if _subparts:
for p in _subparts:
self.attach(p)
if boundary:
self.set_boundary(boundary)

View file

@@ -0,0 +1,24 @@
# Copyright (C) 2002-2006 Python Software Foundation
# Author: Barry Warsaw
# Contact: email-sig@python.org
"""Base class for MIME type messages that are not multipart."""
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
__all__ = ['MIMENonMultipart']
from future.backports.email import errors
from future.backports.email.mime.base import MIMEBase
class MIMENonMultipart(MIMEBase):
"""Base class for MIME multipart/* type messages."""
def attach(self, payload):
# The public API prohibits attaching multiple subparts to MIMEBase
# derived subtypes since none of them are, by definition, of content
# type multipart/*
raise errors.MultipartConversionError(
'Cannot attach additional subparts to non-multipart/*')

View file

@@ -0,0 +1,44 @@
# Copyright (C) 2001-2006 Python Software Foundation
# Author: Barry Warsaw
# Contact: email-sig@python.org
"""Class representing text/* type MIME documents."""
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
__all__ = ['MIMEText']
from future.backports.email.encoders import encode_7or8bit
from future.backports.email.mime.nonmultipart import MIMENonMultipart
class MIMEText(MIMENonMultipart):
"""Class for generating text/* type MIME documents."""
def __init__(self, _text, _subtype='plain', _charset=None):
"""Create a text/* type MIME document.
_text is the string for this message object.
_subtype is the MIME sub content type, defaulting to "plain".
_charset is the character set parameter added to the Content-Type
header. This defaults to "us-ascii". Note that as a side-effect, the
Content-Transfer-Encoding header will also be set.
"""
# If no _charset was specified, check to see if there are non-ascii
# characters present. If not, use 'us-ascii', otherwise use utf-8.
# XXX: This can be removed once #7304 is fixed.
if _charset is None:
try:
_text.encode('us-ascii')
_charset = 'us-ascii'
except UnicodeEncodeError:
_charset = 'utf-8'
MIMENonMultipart.__init__(self, 'text', _subtype,
**{'charset': _charset})
self.set_payload(_text, _charset)
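# Illustrative sketch (not part of this module): combining MIMEText with
# MIMEMultipart (from the sibling module added in this commit) to build a
# small multipart/alternative message.
from future.backports.email.mime.multipart import MIMEMultipart
_outer = MIMEMultipart('alternative')
_outer['Subject'] = 'Greetings'
_outer.attach(MIMEText('hello'))
_outer.attach(MIMEText('<p>hello</p>', 'html'))
_outer.get_content_type()        # 'multipart/alternative'
len(_outer.get_payload())        # 2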

View file

@@ -0,0 +1,135 @@
# Copyright (C) 2001-2007 Python Software Foundation
# Author: Barry Warsaw, Thomas Wouters, Anthony Baxter
# Contact: email-sig@python.org
"""A parser of RFC 2822 and MIME email messages."""
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
__all__ = ['Parser', 'HeaderParser', 'BytesParser', 'BytesHeaderParser']
import warnings
from io import StringIO, TextIOWrapper
from future.backports.email.feedparser import FeedParser, BytesFeedParser
from future.backports.email.message import Message
from future.backports.email._policybase import compat32
class Parser(object):
def __init__(self, _class=Message, **_3to2kwargs):
"""Parser of RFC 2822 and MIME email messages.
Creates an in-memory object tree representing the email message, which
can then be manipulated and turned over to a Generator to return the
textual representation of the message.
The string must be formatted as a block of RFC 2822 headers and header
continuation lines, optionally preceded by a `Unix-from' header. The
header block is terminated either by the end of the string or by a
blank line.
_class is the class to instantiate for new message objects when they
must be created. This class must have a constructor that can take
zero arguments. Default is Message.Message.
The policy keyword specifies a policy object that controls a number of
aspects of the parser's operation. The default policy maintains
backward compatibility.
"""
if 'policy' in _3to2kwargs: policy = _3to2kwargs['policy']; del _3to2kwargs['policy']
else: policy = compat32
self._class = _class
self.policy = policy
def parse(self, fp, headersonly=False):
"""Create a message structure from the data in a file.
Reads all the data from the file and returns the root of the message
structure. Optional headersonly is a flag specifying whether to stop
parsing after reading the headers or not. The default is False,
meaning it parses the entire contents of the file.
"""
feedparser = FeedParser(self._class, policy=self.policy)
if headersonly:
feedparser._set_headersonly()
while True:
data = fp.read(8192)
if not data:
break
feedparser.feed(data)
return feedparser.close()
def parsestr(self, text, headersonly=False):
"""Create a message structure from a string.
Returns the root of the message structure. Optional headersonly is a
flag specifying whether to stop parsing after reading the headers or
not. The default is False, meaning it parses the entire contents of
the file.
"""
return self.parse(StringIO(text), headersonly=headersonly)
class HeaderParser(Parser):
def parse(self, fp, headersonly=True):
return Parser.parse(self, fp, True)
def parsestr(self, text, headersonly=True):
return Parser.parsestr(self, text, True)
class BytesParser(object):
def __init__(self, *args, **kw):
"""Parser of binary RFC 2822 and MIME email messages.
Creates an in-memory object tree representing the email message, which
can then be manipulated and turned over to a Generator to return the
textual representation of the message.
The input must be formatted as a block of RFC 2822 headers and header
continuation lines, optionally preceded by a `Unix-from' header. The
header block is terminated either by the end of the input or by a
blank line.
_class is the class to instantiate for new message objects when they
must be created. This class must have a constructor that can take
zero arguments. Default is Message.Message.
"""
self.parser = Parser(*args, **kw)
def parse(self, fp, headersonly=False):
"""Create a message structure from the data in a binary file.
Reads all the data from the file and returns the root of the message
structure. Optional headersonly is a flag specifying whether to stop
parsing after reading the headers or not. The default is False,
meaning it parses the entire contents of the file.
"""
fp = TextIOWrapper(fp, encoding='ascii', errors='surrogateescape')
with fp:
return self.parser.parse(fp, headersonly)
def parsebytes(self, text, headersonly=False):
"""Create a message structure from a byte string.
Returns the root of the message structure. Optional headersonly is a
flag specifying whether to stop parsing after reading the headers or
not. The default is False, meaning it parses the entire contents of
the file.
"""
text = text.decode('ASCII', errors='surrogateescape')
return self.parser.parsestr(text, headersonly)
class BytesHeaderParser(BytesParser):
def parse(self, fp, headersonly=True):
return BytesParser.parse(self, fp, headersonly=True)
def parsebytes(self, text, headersonly=True):
return BytesParser.parsebytes(self, text, headersonly=True)
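# Illustrative sketch (not part of this module): parsing a small message
# from a string and from bytes with the classes defined above.
_raw = 'Subject: test\nTo: eric@example.com\n\nbody\n'
_msg = Parser().parsestr(_raw)
_msg['Subject']                  # 'test'
_msg.get_payload()               # 'body\n'
_bmsg = BytesParser().parsebytes(_raw.encode('ascii'))
_bmsg['To']                      # 'eric@example.com'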

View file

@@ -0,0 +1,193 @@
"""This will be the home for the policy that hooks in the new
code that adds all the email6 features.
"""
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from future.builtins import super
from future.standard_library.email._policybase import (Policy, Compat32,
compat32, _extend_docstrings)
from future.standard_library.email.utils import _has_surrogates
from future.standard_library.email.headerregistry import HeaderRegistry as HeaderRegistry
__all__ = [
'Compat32',
'compat32',
'Policy',
'EmailPolicy',
'default',
'strict',
'SMTP',
'HTTP',
]
@_extend_docstrings
class EmailPolicy(Policy):
"""+
PROVISIONAL
The API extensions enabled by this policy are currently provisional.
Refer to the documentation for details.
This policy adds new header parsing and folding algorithms. Instead of
simple strings, headers are custom objects with custom attributes
depending on the type of the field. The folding algorithm fully
implements RFCs 2047 and 5322.
In addition to the settable attributes listed above that apply to
all Policies, this policy adds the following additional attributes:
refold_source -- if the value for a header in the Message object
came from the parsing of some source, this attribute
indicates whether or not a generator should refold
that value when transforming the message back into
stream form. The possible values are:
none -- all source values use original folding
long -- source values that have any line that is
longer than max_line_length will be
refolded
all -- all values are refolded.
The default is 'long'.
header_factory -- a callable that takes two arguments, 'name' and
'value', where 'name' is a header field name and
'value' is an unfolded header field value, and
returns a string-like object that represents that
header. A default header_factory is provided that
understands some of the RFC5322 header field types.
(Currently address fields and date fields have
special treatment, while all other fields are
treated as unstructured. This list will be
completed before the extension is marked stable.)
"""
refold_source = 'long'
header_factory = HeaderRegistry()
def __init__(self, **kw):
# Ensure that each new instance gets a unique header factory
# (as opposed to clones, which share the factory).
if 'header_factory' not in kw:
object.__setattr__(self, 'header_factory', HeaderRegistry())
super().__init__(**kw)
def header_max_count(self, name):
"""+
The implementation for this class returns the max_count attribute from
the specialized header class that would be used to construct a header
of type 'name'.
"""
return self.header_factory[name].max_count
# The logic of the next three methods is chosen such that it is possible to
# switch a Message object between a Compat32 policy and a policy derived
# from this class and have the results stay consistent. This allows a
# Message object constructed with this policy to be passed to a library
# that only handles Compat32 objects, or to receive such an object and
# convert it to use the newer style by just changing its policy. It is
# also chosen because it postpones the relatively expensive full rfc5322
# parse until as late as possible when parsing from source, since in many
# applications only a few headers will actually be inspected.
def header_source_parse(self, sourcelines):
"""+
The name is parsed as everything up to the ':' and returned unmodified.
The value is determined by stripping leading whitespace off the
remainder of the first line, joining all subsequent lines together, and
stripping any trailing carriage return or linefeed characters. (This
is the same as Compat32).
"""
name, value = sourcelines[0].split(':', 1)
value = value.lstrip(' \t') + ''.join(sourcelines[1:])
return (name, value.rstrip('\r\n'))
def header_store_parse(self, name, value):
"""+
The name is returned unchanged. If the input value has a 'name'
attribute and it matches the name ignoring case, the value is returned
unchanged. Otherwise the name and value are passed to header_factory
method, and the resulting custom header object is returned as the
value. In this case a ValueError is raised if the input value contains
CR or LF characters.
"""
if hasattr(value, 'name') and value.name.lower() == name.lower():
return (name, value)
if isinstance(value, str) and len(value.splitlines())>1:
raise ValueError("Header values may not contain linefeed "
"or carriage return characters")
return (name, self.header_factory(name, value))
def header_fetch_parse(self, name, value):
"""+
If the value has a 'name' attribute, it is returned unmodified.
Otherwise the name and the value with any linesep characters removed
are passed to the header_factory method, and the resulting custom
header object is returned. Any surrogateescaped bytes get turned
into the unicode unknown-character glyph.
"""
if hasattr(value, 'name'):
return value
return self.header_factory(name, ''.join(value.splitlines()))
def fold(self, name, value):
"""+
Header folding is controlled by the refold_source policy setting. A
value is considered to be a 'source value' if and only if it does not
have a 'name' attribute (having a 'name' attribute means it is a header
object of some sort). If a source value needs to be refolded according
to the policy, it is converted into a custom header object by passing
the name and the value with any linesep characters removed to the
header_factory method. Folding of a custom header object is done by
calling its fold method with the current policy.
Source values are split into lines using splitlines. If the value is
not to be refolded, the lines are rejoined using the linesep from the
policy and returned. The exception is lines containing non-ascii
binary data. In that case the value is refolded regardless of the
refold_source setting, which causes the binary data to be CTE encoded
using the unknown-8bit charset.
"""
return self._fold(name, value, refold_binary=True)
def fold_binary(self, name, value):
"""+
The same as fold if cte_type is 7bit, except that the returned value is
bytes.
If cte_type is 8bit, non-ASCII binary data is converted back into
bytes. Headers with binary data are not refolded, regardless of the
refold_source setting, since there is no way to know whether the binary
data consists of single byte characters or multibyte characters.
"""
folded = self._fold(name, value, refold_binary=self.cte_type=='7bit')
return folded.encode('ascii', 'surrogateescape')
def _fold(self, name, value, refold_binary=False):
if hasattr(value, 'name'):
return value.fold(policy=self)
maxlen = self.max_line_length if self.max_line_length else float('inf')
lines = value.splitlines()
refold = (self.refold_source == 'all' or
self.refold_source == 'long' and
(lines and len(lines[0])+len(name)+2 > maxlen or
any(len(x) > maxlen for x in lines[1:])))
if refold or refold_binary and _has_surrogates(value):
return self.header_factory(name, ''.join(lines)).fold(policy=self)
return name + ': ' + self.linesep.join(lines) + self.linesep
default = EmailPolicy()
# Make the default policy use the class default header_factory
del default.header_factory
strict = default.clone(raise_on_defect=True)
SMTP = default.clone(linesep='\r\n')
HTTP = default.clone(linesep='\r\n', max_line_length=None)
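# --- Hedged usage sketch (editor's addition, not upstream code). It shows how
# the provisional EmailPolicy above is typically used: clone the module default
# with a narrower line length, then fold a plain-string header value.
if __name__ == '__main__':
    compact = default.clone(max_line_length=40)
    print(compact.header_max_count('to'))   # max_count of the registered To header class
    print(compact.fold('Subject', 'a long subject line that will be refolded '
                                  'because it exceeds forty characters'))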

View file

@@ -0,0 +1,326 @@
# Copyright (C) 2001-2006 Python Software Foundation
# Author: Ben Gertzfield
# Contact: email-sig@python.org
"""Quoted-printable content transfer encoding per RFCs 2045-2047.
This module handles the content transfer encoding method defined in RFC 2045
to encode US ASCII-like 8-bit data called `quoted-printable'. It is used to
safely encode text that is in a character set similar to the 7-bit US ASCII
character set, but that includes some 8-bit characters that are normally not
allowed in email bodies or headers.
Quoted-printable is very space-inefficient for encoding binary files; use the
email.base64mime module for that instead.
This module provides an interface to encode and decode both headers and bodies
with quoted-printable encoding.
RFC 2045 defines a method for including character set information in an
`encoded-word' in a header. This method is commonly used for 8-bit real names
in To:/From:/Cc: etc. fields, as well as Subject: lines.
This module does not do the line wrapping or end-of-line character
conversion necessary for proper internationalized headers; it only
does dumb encoding and decoding. To deal with the various line
wrapping issues, use the email.header module.
"""
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from future.builtins import bytes, chr, dict, int, range, super
__all__ = [
'body_decode',
'body_encode',
'body_length',
'decode',
'decodestring',
'header_decode',
'header_encode',
'header_length',
'quote',
'unquote',
]
import re
import io
from string import ascii_letters, digits, hexdigits
CRLF = '\r\n'
NL = '\n'
EMPTYSTRING = ''
# Build a mapping of octets to the expansion of that octet. Since we're only
# going to have 256 of these things, this isn't terribly inefficient
# space-wise. Remember that headers and bodies have different sets of safe
# characters. Initialize both maps with the full expansion, and then override
# the safe bytes with the more compact form.
_QUOPRI_HEADER_MAP = dict((c, '=%02X' % c) for c in range(256))
_QUOPRI_BODY_MAP = _QUOPRI_HEADER_MAP.copy()
# Safe header bytes which need no encoding.
for c in bytes(b'-!*+/' + ascii_letters.encode('ascii') + digits.encode('ascii')):
_QUOPRI_HEADER_MAP[c] = chr(c)
# Headers have one other special encoding; spaces become underscores.
_QUOPRI_HEADER_MAP[ord(' ')] = '_'
# Safe body bytes which need no encoding.
for c in bytes(b' !"#$%&\'()*+,-./0123456789:;<>'
b'?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`'
b'abcdefghijklmnopqrstuvwxyz{|}~\t'):
_QUOPRI_BODY_MAP[c] = chr(c)
# Helpers
def header_check(octet):
"""Return True if the octet should be escaped with header quopri."""
return chr(octet) != _QUOPRI_HEADER_MAP[octet]
def body_check(octet):
"""Return True if the octet should be escaped with body quopri."""
return chr(octet) != _QUOPRI_BODY_MAP[octet]
def header_length(bytearray):
"""Return a header quoted-printable encoding length.
Note that this does not include any RFC 2047 chrome added by
`header_encode()`.
:param bytearray: An array of bytes (a.k.a. octets).
:return: The length in bytes of the byte array when it is encoded with
quoted-printable for headers.
"""
return sum(len(_QUOPRI_HEADER_MAP[octet]) for octet in bytearray)
def body_length(bytearray):
"""Return a body quoted-printable encoding length.
:param bytearray: An array of bytes (a.k.a. octets).
:return: The length in bytes of the byte array when it is encoded with
quoted-printable for bodies.
"""
return sum(len(_QUOPRI_BODY_MAP[octet]) for octet in bytearray)
def _max_append(L, s, maxlen, extra=''):
if not isinstance(s, str):
s = chr(s)
if not L:
L.append(s.lstrip())
elif len(L[-1]) + len(s) <= maxlen:
L[-1] += extra + s
else:
L.append(s.lstrip())
def unquote(s):
"""Turn a string in the form =AB to the ASCII character with value 0xab"""
return chr(int(s[1:3], 16))
def quote(c):
return '=%02X' % ord(c)
def header_encode(header_bytes, charset='iso-8859-1'):
"""Encode a single header line with quoted-printable (like) encoding.
Defined in RFC 2045, this `Q' encoding is similar to quoted-printable, but
used specifically for email header fields to allow charsets with mostly 7
bit characters (and some 8 bit) to remain more or less readable in non-RFC
2045 aware mail clients.
charset names the character set to use in the RFC 2047 encoded-word. It
defaults to iso-8859-1.
"""
# Return empty headers as an empty string.
if not header_bytes:
return ''
# Iterate over every byte, encoding if necessary.
encoded = []
for octet in header_bytes:
encoded.append(_QUOPRI_HEADER_MAP[octet])
# Now add the RFC chrome to each encoded chunk and glue the chunks
# together.
return '=?%s?q?%s?=' % (charset, EMPTYSTRING.join(encoded))
class _body_accumulator(io.StringIO):
def __init__(self, maxlinelen, eol, *args, **kw):
super().__init__(*args, **kw)
self.eol = eol
self.maxlinelen = self.room = maxlinelen
def write_str(self, s):
"""Add string s to the accumulated body."""
self.write(s)
self.room -= len(s)
def newline(self):
"""Write eol, then start new line."""
self.write_str(self.eol)
self.room = self.maxlinelen
def write_soft_break(self):
"""Write a soft break, then start a new line."""
self.write_str('=')
self.newline()
def write_wrapped(self, s, extra_room=0):
"""Add a soft line break if needed, then write s."""
if self.room < len(s) + extra_room:
self.write_soft_break()
self.write_str(s)
def write_char(self, c, is_last_char):
if not is_last_char:
# Another character follows on this line, so we must leave
# extra room, either for it or a soft break, and whitespace
# need not be quoted.
self.write_wrapped(c, extra_room=1)
elif c not in ' \t':
# For this and remaining cases, no more characters follow,
# so there is no need to reserve extra room (since a hard
# break will immediately follow).
self.write_wrapped(c)
elif self.room >= 3:
# It's a whitespace character at end-of-line, and we have room
# for the three-character quoted encoding.
self.write(quote(c))
elif self.room == 2:
# There's room for the whitespace character and a soft break.
self.write(c)
self.write_soft_break()
else:
# There's room only for a soft break. The quoted whitespace
# will be the only content on the subsequent line.
self.write_soft_break()
self.write(quote(c))
def body_encode(body, maxlinelen=76, eol=NL):
"""Encode with quoted-printable, wrapping at maxlinelen characters.
Each line of encoded text will end with eol, which defaults to "\\n". Set
this to "\\r\\n" if you will be using the result of this function directly
in an email.
Each line will be wrapped at, at most, maxlinelen characters before the
eol string (maxlinelen defaults to 76 characters, the maximum value
permitted by RFC 2045). Long lines will have the 'soft line break'
quoted-printable character "=" appended to them, so the decoded text will
be identical to the original text.
The minimum maxlinelen is 4 to have room for a quoted character ("=XX")
followed by a soft line break. Smaller values will generate a
ValueError.
"""
if maxlinelen < 4:
raise ValueError("maxlinelen must be at least 4")
if not body:
return body
# The last line may or may not end in eol, but all other lines do.
last_has_eol = (body[-1] in '\r\n')
# This accumulator will make it easier to build the encoded body.
encoded_body = _body_accumulator(maxlinelen, eol)
lines = body.splitlines()
last_line_no = len(lines) - 1
for line_no, line in enumerate(lines):
last_char_index = len(line) - 1
for i, c in enumerate(line):
if body_check(ord(c)):
c = quote(c)
encoded_body.write_char(c, i==last_char_index)
# Add an eol if input line had eol. All input lines have eol except
# possibly the last one.
if line_no < last_line_no or last_has_eol:
encoded_body.newline()
return encoded_body.getvalue()
# BAW: I'm not sure if the intent was for the signature of this function to be
# the same as base64MIME.decode() or not...
def decode(encoded, eol=NL):
"""Decode a quoted-printable string.
Lines are separated with eol, which defaults to \\n.
"""
if not encoded:
return encoded
# BAW: see comment in encode() above. Again, we're building up the
# decoded string with string concatenation, which could be done much more
# efficiently.
decoded = ''
for line in encoded.splitlines():
line = line.rstrip()
if not line:
decoded += eol
continue
i = 0
n = len(line)
while i < n:
c = line[i]
if c != '=':
decoded += c
i += 1
# Otherwise, c == "=". Are we at the end of the line? If so, add
# a soft line break.
elif i+1 == n:
i += 1
continue
# Decode if in form =AB
elif i+2 < n and line[i+1] in hexdigits and line[i+2] in hexdigits:
decoded += unquote(line[i:i+3])
i += 3
# Otherwise, not in form =AB, pass literally
else:
decoded += c
i += 1
if i == n:
decoded += eol
# Special case if original string did not end with eol
if encoded[-1] not in '\r\n' and decoded.endswith(eol):
decoded = decoded[:-1]
return decoded
# For convenience and backwards compatibility w/ standard base64 module
body_decode = decode
decodestring = decode
def _unquote_match(match):
"""Turn a match in the form =AB to the ASCII character with value 0xab"""
s = match.group(0)
return unquote(s)
# Header decoding is done a bit differently
def header_decode(s):
"""Decode a string encoded with RFC 2045 MIME header `Q' encoding.
This function does not parse a full MIME header value encoded with
quoted-printable (like =?iso-8859-1?q?Hello_World?=) -- please use
the high level email.header class for that functionality.
"""
s = s.replace('_', ' ')
return re.sub(r'=[a-fA-F0-9]{2}', _unquote_match, s, flags=re.ASCII)
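# --- Hedged usage sketch (editor's addition, not upstream code). Round-trips
# a short body and a header fragment through the helpers defined above; the
# values shown in the comments are indicative.
if __name__ == '__main__':
    encoded = body_encode('caf\xe9 au lait\n')
    print(encoded)                           # 'caf=E9 au lait' plus the eol
    print(decode(encoded))                   # back to 'caf\xe9 au lait\n'
    print(header_encode(b'caf\xe9'))         # '=?iso-8859-1?q?caf=E9?='
    print(header_decode('Hello_World=21'))   # 'Hello World!'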

View file

@@ -0,0 +1,400 @@
# Copyright (C) 2001-2010 Python Software Foundation
# Author: Barry Warsaw
# Contact: email-sig@python.org
"""Miscellaneous utilities."""
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from future import utils
from future.builtins import bytes, int, str
__all__ = [
'collapse_rfc2231_value',
'decode_params',
'decode_rfc2231',
'encode_rfc2231',
'formataddr',
'formatdate',
'format_datetime',
'getaddresses',
'make_msgid',
'mktime_tz',
'parseaddr',
'parsedate',
'parsedate_tz',
'parsedate_to_datetime',
'unquote',
]
import os
import re
if utils.PY2:
re.ASCII = 0
import time
import base64
import random
import socket
from future.backports import datetime
from future.backports.urllib.parse import quote as url_quote, unquote as url_unquote
import warnings
from io import StringIO
from future.backports.email._parseaddr import quote
from future.backports.email._parseaddr import AddressList as _AddressList
from future.backports.email._parseaddr import mktime_tz
from future.backports.email._parseaddr import parsedate, parsedate_tz, _parsedate_tz
from quopri import decodestring as _qdecode
# Intrapackage imports
from future.backports.email.encoders import _bencode, _qencode
from future.backports.email.charset import Charset
COMMASPACE = ', '
EMPTYSTRING = ''
UEMPTYSTRING = ''
CRLF = '\r\n'
TICK = "'"
specialsre = re.compile(r'[][\\()<>@,:;".]')
escapesre = re.compile(r'[\\"]')
# How to figure out if we are processing strings that come from a byte
# source with undecodable characters.
_has_surrogates = re.compile(
'([^\ud800-\udbff]|\A)[\udc00-\udfff]([^\udc00-\udfff]|\Z)').search
# How to deal with a string containing bytes before handing it to the
# application through the 'normal' interface.
def _sanitize(string):
# Turn any escaped bytes into unicode 'unknown' char.
original_bytes = string.encode('ascii', 'surrogateescape')
return original_bytes.decode('ascii', 'replace')
# Helpers
def formataddr(pair, charset='utf-8'):
"""The inverse of parseaddr(), this takes a 2-tuple of the form
(realname, email_address) and returns the string value suitable
for an RFC 2822 From, To or Cc header.
If the first element of pair is false, then the second element is
returned unmodified.
Optional charset if given is the character set that is used to encode
realname in case realname is not ASCII safe. Can be an instance of str or
a Charset-like object which has a header_encode method. Default is
'utf-8'.
"""
name, address = pair
# The address MUST (per RFC) be ascii, so raise a UnicodeError if it isn't.
address.encode('ascii')
if name:
try:
name.encode('ascii')
except UnicodeEncodeError:
if isinstance(charset, str):
charset = Charset(charset)
encoded_name = charset.header_encode(name)
return "%s <%s>" % (encoded_name, address)
else:
quotes = ''
if specialsre.search(name):
quotes = '"'
name = escapesre.sub(r'\\\g<0>', name)
return '%s%s%s <%s>' % (quotes, name, quotes, address)
return address
def getaddresses(fieldvalues):
"""Return a list of (REALNAME, EMAIL) for each fieldvalue."""
all = COMMASPACE.join(fieldvalues)
a = _AddressList(all)
return a.addresslist
ecre = re.compile(r'''
=\? # literal =?
(?P<charset>[^?]*?) # non-greedy up to the next ? is the charset
\? # literal ?
(?P<encoding>[qb]) # either a "q" or a "b", case insensitive
\? # literal ?
(?P<atom>.*?) # non-greedy up to the next ?= is the atom
\?= # literal ?=
''', re.VERBOSE | re.IGNORECASE)
def _format_timetuple_and_zone(timetuple, zone):
return '%s, %02d %s %04d %02d:%02d:%02d %s' % (
['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'][timetuple[6]],
timetuple[2],
['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'][timetuple[1] - 1],
timetuple[0], timetuple[3], timetuple[4], timetuple[5],
zone)
def formatdate(timeval=None, localtime=False, usegmt=False):
"""Returns a date string as specified by RFC 2822, e.g.:
Fri, 09 Nov 2001 01:08:47 -0000
Optional timeval if given is a floating point time value as accepted by
gmtime() and localtime(), otherwise the current time is used.
Optional localtime is a flag that when True, interprets timeval, and
returns a date relative to the local timezone instead of UTC, properly
taking daylight savings time into account.
Optional argument usegmt means that the timezone is written out as
an ascii string, not numeric one (so "GMT" instead of "+0000"). This
is needed for HTTP, and is only used when localtime==False.
"""
# Note: we cannot use strftime() because that honors the locale and RFC
# 2822 requires that day and month names be the English abbreviations.
if timeval is None:
timeval = time.time()
if localtime:
now = time.localtime(timeval)
# Calculate timezone offset, based on whether the local zone has
# daylight savings time, and whether DST is in effect.
if time.daylight and now[-1]:
offset = time.altzone
else:
offset = time.timezone
hours, minutes = divmod(abs(offset), 3600)
# Remember offset is in seconds west of UTC, but the timezone is in
# minutes east of UTC, so the signs differ.
if offset > 0:
sign = '-'
else:
sign = '+'
zone = '%s%02d%02d' % (sign, hours, minutes // 60)
else:
now = time.gmtime(timeval)
# Timezone offset is always -0000
if usegmt:
zone = 'GMT'
else:
zone = '-0000'
return _format_timetuple_and_zone(now, zone)
def format_datetime(dt, usegmt=False):
"""Turn a datetime into a date string as specified in RFC 2822.
If usegmt is True, dt must be an aware datetime with an offset of zero. In
this case 'GMT' will be rendered instead of the normal +0000 required by
RFC2822. This is to support HTTP headers involving date stamps.
"""
now = dt.timetuple()
if usegmt:
if dt.tzinfo is None or dt.tzinfo != datetime.timezone.utc:
raise ValueError("usegmt option requires a UTC datetime")
zone = 'GMT'
elif dt.tzinfo is None:
zone = '-0000'
else:
zone = dt.strftime("%z")
return _format_timetuple_and_zone(now, zone)
def make_msgid(idstring=None, domain=None):
"""Returns a string suitable for RFC 2822 compliant Message-ID, e.g:
<20020201195627.33539.96671@nightshade.la.mastaler.com>
Optional idstring if given is a string used to strengthen the
uniqueness of the message id. Optional domain if given provides the
portion of the message id after the '@'. It defaults to the locally
defined hostname.
"""
timeval = time.time()
utcdate = time.strftime('%Y%m%d%H%M%S', time.gmtime(timeval))
pid = os.getpid()
randint = random.randrange(100000)
if idstring is None:
idstring = ''
else:
idstring = '.' + idstring
if domain is None:
domain = socket.getfqdn()
msgid = '<%s.%s.%s%s@%s>' % (utcdate, pid, randint, idstring, domain)
return msgid
def parsedate_to_datetime(data):
_3to2list = list(_parsedate_tz(data))
dtuple, tz, = [_3to2list[:-1]] + _3to2list[-1:]
if tz is None:
return datetime.datetime(*dtuple[:6])
return datetime.datetime(*dtuple[:6],
tzinfo=datetime.timezone(datetime.timedelta(seconds=tz)))
def parseaddr(addr):
addrs = _AddressList(addr).addresslist
if not addrs:
return '', ''
return addrs[0]
# rfc822.unquote() doesn't properly de-backslash-ify in Python pre-2.3.
def unquote(str):
"""Remove quotes from a string."""
if len(str) > 1:
if str.startswith('"') and str.endswith('"'):
return str[1:-1].replace('\\\\', '\\').replace('\\"', '"')
if str.startswith('<') and str.endswith('>'):
return str[1:-1]
return str
# RFC2231-related functions - parameter encoding and decoding
def decode_rfc2231(s):
"""Decode string according to RFC 2231"""
parts = s.split(TICK, 2)
if len(parts) <= 2:
return None, None, s
return parts
def encode_rfc2231(s, charset=None, language=None):
"""Encode string according to RFC 2231.
If neither charset nor language is given, then s is returned as-is. If
charset is given but not language, the string is encoded using the empty
string for language.
"""
s = url_quote(s, safe='', encoding=charset or 'ascii')
if charset is None and language is None:
return s
if language is None:
language = ''
return "%s'%s'%s" % (charset, language, s)
rfc2231_continuation = re.compile(r'^(?P<name>\w+)\*((?P<num>[0-9]+)\*?)?$',
re.ASCII)
def decode_params(params):
"""Decode parameters list according to RFC 2231.
params is a sequence of 2-tuples containing (param name, string value).
"""
# Copy params so we don't mess with the original
params = params[:]
new_params = []
# Map parameter's name to a list of continuations. The values are a
# 3-tuple of the continuation number, the string value, and a flag
# specifying whether a particular segment is %-encoded.
rfc2231_params = {}
name, value = params.pop(0)
new_params.append((name, value))
while params:
name, value = params.pop(0)
if name.endswith('*'):
encoded = True
else:
encoded = False
value = unquote(value)
mo = rfc2231_continuation.match(name)
if mo:
name, num = mo.group('name', 'num')
if num is not None:
num = int(num)
rfc2231_params.setdefault(name, []).append((num, value, encoded))
else:
new_params.append((name, '"%s"' % quote(value)))
if rfc2231_params:
for name, continuations in rfc2231_params.items():
value = []
extended = False
# Sort by number
continuations.sort()
# And now append all values in numerical order, converting
# %-encodings for the encoded segments. If any of the
# continuation names ends in a *, then the entire string, after
# decoding segments and concatenating, must have the charset and
# language specifiers at the beginning of the string.
for num, s, encoded in continuations:
if encoded:
# Decode as "latin-1", so the characters in s directly
# represent the percent-encoded octet values.
# collapse_rfc2231_value treats this as an octet sequence.
s = url_unquote(s, encoding="latin-1")
extended = True
value.append(s)
value = quote(EMPTYSTRING.join(value))
if extended:
charset, language, value = decode_rfc2231(value)
new_params.append((name, (charset, language, '"%s"' % value)))
else:
new_params.append((name, '"%s"' % value))
return new_params
def collapse_rfc2231_value(value, errors='replace',
fallback_charset='us-ascii'):
if not isinstance(value, tuple) or len(value) != 3:
return unquote(value)
# While value comes to us as a unicode string, we need it to be a bytes
# object. We do not want bytes() normal utf-8 decoder, we want a straight
# interpretation of the string as character bytes.
charset, language, text = value
rawbytes = bytes(text, 'raw-unicode-escape')
try:
return str(rawbytes, charset, errors)
except LookupError:
# charset is not a known codec.
return unquote(text)
#
# datetime doesn't provide a localtime function yet, so provide one. Code
# adapted from the patch in issue 9527. This may not be perfect, but it is
# better than not having it.
#
def localtime(dt=None, isdst=-1):
"""Return local time as an aware datetime object.
If called without arguments, return current time. Otherwise *dt*
argument should be a datetime instance, and it is converted to the
local time zone according to the system time zone database. If *dt* is
naive (that is, dt.tzinfo is None), it is assumed to be in local time.
In this case, a positive or zero value for *isdst* causes localtime to
presume initially that summer time (for example, Daylight Saving Time)
is or is not (respectively) in effect for the specified time. A
negative value for *isdst* causes the localtime() function to attempt
to divine whether summer time is in effect for the specified time.
"""
if dt is None:
return datetime.datetime.now(datetime.timezone.utc).astimezone()
if dt.tzinfo is not None:
return dt.astimezone()
# We have a naive datetime. Convert to a (localtime) timetuple and pass to
# system mktime together with the isdst hint. System mktime will return
# seconds since epoch.
tm = dt.timetuple()[:-1] + (isdst,)
seconds = time.mktime(tm)
localtm = time.localtime(seconds)
try:
delta = datetime.timedelta(seconds=localtm.tm_gmtoff)
tz = datetime.timezone(delta, localtm.tm_zone)
except AttributeError:
# Compute UTC offset and compare with the value implied by tm_isdst.
# If the values match, use the zone name implied by tm_isdst.
delta = dt - datetime.datetime(*time.gmtime(seconds)[:6])
dst = time.daylight and localtm.tm_isdst > 0
gmtoff = -(time.altzone if dst else time.timezone)
if delta == datetime.timedelta(seconds=gmtoff):
tz = datetime.timezone(delta, time.tzname[dst])
else:
tz = datetime.timezone(delta)
return dt.replace(tzinfo=tz)
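# --- Hedged usage sketch (editor's addition, not upstream code). A few of the
# helpers above in action; the date and message-id shown in the comments are
# examples only.
if __name__ == '__main__':
    addr = formataddr(('Jane "JD" Doe', 'jane@example.com'))
    print(addr)                      # '"Jane \"JD\" Doe" <jane@example.com>'
    print(parseaddr(addr))           # ('Jane "JD" Doe', 'jane@example.com')
    print(formatdate(usegmt=True))   # e.g. 'Wed, 10 Jun 2020 16:04:54 GMT'
    print(make_msgid('demo'))        # e.g. '<20200610160454.123.456.demo@host>'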

View file

@@ -0,0 +1,27 @@
"""
General functions for HTML manipulation, backported from Py3.
Note that this uses Python 2.7 code with the corresponding Python 3
module names and locations.
"""
from __future__ import unicode_literals
_escape_map = {ord('&'): '&amp;', ord('<'): '&lt;', ord('>'): '&gt;'}
_escape_map_full = {ord('&'): '&amp;', ord('<'): '&lt;', ord('>'): '&gt;',
ord('"'): '&quot;', ord('\''): '&#x27;'}
# NB: this is a candidate for a bytes/string polymorphic interface
def escape(s, quote=True):
"""
Replace special characters "&", "<" and ">" to HTML-safe sequences.
If the optional flag quote is true (the default), the quotation mark
characters, both double quote (") and single quote ('), are also
translated.
"""
assert not isinstance(s, bytes), 'Pass a unicode string'
if quote:
return s.translate(_escape_map_full)
return s.translate(_escape_map)
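# --- Hedged usage sketch (editor's addition, not upstream code) for the
# escape() helper defined above.
if __name__ == '__main__':
    print(escape('<a href="x">Tom & Jerry</a>'))
    # -> &lt;a href=&quot;x&quot;&gt;Tom &amp; Jerry&lt;/a&gt;
    print(escape("it's > 5", quote=False))   # quote chars left alone: it's &gt; 5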

File diff suppressed because it is too large

View file

@@ -0,0 +1,536 @@
"""A parser for HTML and XHTML.
Backported for python-future from Python 3.3.
"""
# This file is based on sgmllib.py, but the API is slightly different.
# XXX There should be a way to distinguish between PCDATA (parsed
# character data -- the normal case), RCDATA (replaceable character
# data -- only char and entity references and end tags are special)
# and CDATA (character data -- only end tags are special).
from __future__ import (absolute_import, division,
print_function, unicode_literals)
from future.builtins import *
from future.backports import _markupbase
import re
import warnings
# Regular expressions used for parsing
interesting_normal = re.compile('[&<]')
incomplete = re.compile('&[a-zA-Z#]')
entityref = re.compile('&([a-zA-Z][-.a-zA-Z0-9]*)[^a-zA-Z0-9]')
charref = re.compile('&#(?:[0-9]+|[xX][0-9a-fA-F]+)[^0-9a-fA-F]')
starttagopen = re.compile('<[a-zA-Z]')
piclose = re.compile('>')
commentclose = re.compile(r'--\s*>')
tagfind = re.compile('([a-zA-Z][-.a-zA-Z0-9:_]*)(?:\s|/(?!>))*')
# see http://www.w3.org/TR/html5/tokenization.html#tag-open-state
# and http://www.w3.org/TR/html5/tokenization.html#tag-name-state
tagfind_tolerant = re.compile('[a-zA-Z][^\t\n\r\f />\x00]*')
# Note:
# 1) the strict attrfind isn't really strict, but we can't make it
# correctly strict without breaking backward compatibility;
# 2) if you change attrfind remember to update locatestarttagend too;
# 3) if you change attrfind and/or locatestarttagend the parser will
# explode, so don't do it.
attrfind = re.compile(
r'\s*([a-zA-Z_][-.:a-zA-Z_0-9]*)(\s*=\s*'
r'(\'[^\']*\'|"[^"]*"|[^\s"\'=<>`]*))?')
attrfind_tolerant = re.compile(
r'((?<=[\'"\s/])[^\s/>][^\s/=>]*)(\s*=+\s*'
r'(\'[^\']*\'|"[^"]*"|(?![\'"])[^>\s]*))?(?:\s|/(?!>))*')
locatestarttagend = re.compile(r"""
<[a-zA-Z][-.a-zA-Z0-9:_]* # tag name
(?:\s+ # whitespace before attribute name
(?:[a-zA-Z_][-.:a-zA-Z0-9_]* # attribute name
(?:\s*=\s* # value indicator
(?:'[^']*' # LITA-enclosed value
|\"[^\"]*\" # LIT-enclosed value
|[^'\">\s]+ # bare value
)
)?
)
)*
\s* # trailing whitespace
""", re.VERBOSE)
locatestarttagend_tolerant = re.compile(r"""
<[a-zA-Z][-.a-zA-Z0-9:_]* # tag name
(?:[\s/]* # optional whitespace before attribute name
(?:(?<=['"\s/])[^\s/>][^\s/=>]* # attribute name
(?:\s*=+\s* # value indicator
(?:'[^']*' # LITA-enclosed value
|"[^"]*" # LIT-enclosed value
|(?!['"])[^>\s]* # bare value
)
(?:\s*,)* # possibly followed by a comma
)?(?:\s|/(?!>))*
)*
)?
\s* # trailing whitespace
""", re.VERBOSE)
endendtag = re.compile('>')
# the HTML 5 spec, section 8.1.2.2, doesn't allow spaces between
# </ and the tag name, so maybe this should be fixed
endtagfind = re.compile('</\s*([a-zA-Z][-.a-zA-Z0-9:_]*)\s*>')
class HTMLParseError(Exception):
"""Exception raised for all parse errors."""
def __init__(self, msg, position=(None, None)):
assert msg
self.msg = msg
self.lineno = position[0]
self.offset = position[1]
def __str__(self):
result = self.msg
if self.lineno is not None:
result = result + ", at line %d" % self.lineno
if self.offset is not None:
result = result + ", column %d" % (self.offset + 1)
return result
class HTMLParser(_markupbase.ParserBase):
"""Find tags and other markup and call handler functions.
Usage:
p = HTMLParser()
p.feed(data)
...
p.close()
Start tags are handled by calling self.handle_starttag() or
self.handle_startendtag(); end tags by self.handle_endtag(). The
data between tags is passed from the parser to the derived class
by calling self.handle_data() with the data as argument (the data
may be split up in arbitrary chunks). Entity references are
passed by calling self.handle_entityref() with the entity
reference as the argument. Numeric character references are
passed to self.handle_charref() with the string containing the
reference as the argument.
"""
CDATA_CONTENT_ELEMENTS = ("script", "style")
def __init__(self, strict=False):
"""Initialize and reset this instance.
If strict is set to False (the default) the parser will parse invalid
markup, otherwise it will raise an error. Note that the strict mode
is deprecated.
"""
if strict:
warnings.warn("The strict mode is deprecated.",
DeprecationWarning, stacklevel=2)
self.strict = strict
self.reset()
def reset(self):
"""Reset this instance. Loses all unprocessed data."""
self.rawdata = ''
self.lasttag = '???'
self.interesting = interesting_normal
self.cdata_elem = None
_markupbase.ParserBase.reset(self)
def feed(self, data):
r"""Feed data to the parser.
Call this as often as you want, with as little or as much text
as you want (may include '\n').
"""
self.rawdata = self.rawdata + data
self.goahead(0)
def close(self):
"""Handle any buffered data."""
self.goahead(1)
def error(self, message):
raise HTMLParseError(message, self.getpos())
__starttag_text = None
def get_starttag_text(self):
"""Return full source of start tag: '<...>'."""
return self.__starttag_text
def set_cdata_mode(self, elem):
self.cdata_elem = elem.lower()
self.interesting = re.compile(r'</\s*%s\s*>' % self.cdata_elem, re.I)
def clear_cdata_mode(self):
self.interesting = interesting_normal
self.cdata_elem = None
# Internal -- handle data as far as reasonable. May leave state
# and data to be processed by a subsequent call. If 'end' is
# true, force handling all data as if followed by EOF marker.
def goahead(self, end):
rawdata = self.rawdata
i = 0
n = len(rawdata)
while i < n:
match = self.interesting.search(rawdata, i) # < or &
if match:
j = match.start()
else:
if self.cdata_elem:
break
j = n
if i < j: self.handle_data(rawdata[i:j])
i = self.updatepos(i, j)
if i == n: break
startswith = rawdata.startswith
if startswith('<', i):
if starttagopen.match(rawdata, i): # < + letter
k = self.parse_starttag(i)
elif startswith("</", i):
k = self.parse_endtag(i)
elif startswith("<!--", i):
k = self.parse_comment(i)
elif startswith("<?", i):
k = self.parse_pi(i)
elif startswith("<!", i):
if self.strict:
k = self.parse_declaration(i)
else:
k = self.parse_html_declaration(i)
elif (i + 1) < n:
self.handle_data("<")
k = i + 1
else:
break
if k < 0:
if not end:
break
if self.strict:
self.error("EOF in middle of construct")
k = rawdata.find('>', i + 1)
if k < 0:
k = rawdata.find('<', i + 1)
if k < 0:
k = i + 1
else:
k += 1
self.handle_data(rawdata[i:k])
i = self.updatepos(i, k)
elif startswith("&#", i):
match = charref.match(rawdata, i)
if match:
name = match.group()[2:-1]
self.handle_charref(name)
k = match.end()
if not startswith(';', k-1):
k = k - 1
i = self.updatepos(i, k)
continue
else:
if ";" in rawdata[i:]: #bail by consuming &#
self.handle_data(rawdata[0:2])
i = self.updatepos(i, 2)
break
elif startswith('&', i):
match = entityref.match(rawdata, i)
if match:
name = match.group(1)
self.handle_entityref(name)
k = match.end()
if not startswith(';', k-1):
k = k - 1
i = self.updatepos(i, k)
continue
match = incomplete.match(rawdata, i)
if match:
# match.group() will contain at least 2 chars
if end and match.group() == rawdata[i:]:
if self.strict:
self.error("EOF in middle of entity or char ref")
else:
if k <= i:
k = n
i = self.updatepos(i, i + 1)
# incomplete
break
elif (i + 1) < n:
# not the end of the buffer, and can't be confused
# with some other construct
self.handle_data("&")
i = self.updatepos(i, i + 1)
else:
break
else:
assert 0, "interesting.search() lied"
# end while
if end and i < n and not self.cdata_elem:
self.handle_data(rawdata[i:n])
i = self.updatepos(i, n)
self.rawdata = rawdata[i:]
# Internal -- parse html declarations, return length or -1 if not terminated
# See w3.org/TR/html5/tokenization.html#markup-declaration-open-state
# See also parse_declaration in _markupbase
def parse_html_declaration(self, i):
rawdata = self.rawdata
assert rawdata[i:i+2] == '<!', ('unexpected call to '
'parse_html_declaration()')
if rawdata[i:i+4] == '<!--':
# this case is actually already handled in goahead()
return self.parse_comment(i)
elif rawdata[i:i+3] == '<![':
return self.parse_marked_section(i)
elif rawdata[i:i+9].lower() == '<!doctype':
# find the closing >
gtpos = rawdata.find('>', i+9)
if gtpos == -1:
return -1
self.handle_decl(rawdata[i+2:gtpos])
return gtpos+1
else:
return self.parse_bogus_comment(i)
# Internal -- parse bogus comment, return length or -1 if not terminated
# see http://www.w3.org/TR/html5/tokenization.html#bogus-comment-state
def parse_bogus_comment(self, i, report=1):
rawdata = self.rawdata
assert rawdata[i:i+2] in ('<!', '</'), ('unexpected call to '
'parse_comment()')
pos = rawdata.find('>', i+2)
if pos == -1:
return -1
if report:
self.handle_comment(rawdata[i+2:pos])
return pos + 1
# Internal -- parse processing instr, return end or -1 if not terminated
def parse_pi(self, i):
rawdata = self.rawdata
assert rawdata[i:i+2] == '<?', 'unexpected call to parse_pi()'
match = piclose.search(rawdata, i+2) # >
if not match:
return -1
j = match.start()
self.handle_pi(rawdata[i+2: j])
j = match.end()
return j
# Internal -- handle starttag, return end or -1 if not terminated
def parse_starttag(self, i):
self.__starttag_text = None
endpos = self.check_for_whole_start_tag(i)
if endpos < 0:
return endpos
rawdata = self.rawdata
self.__starttag_text = rawdata[i:endpos]
# Now parse the data between i+1 and j into a tag and attrs
attrs = []
match = tagfind.match(rawdata, i+1)
assert match, 'unexpected call to parse_starttag()'
k = match.end()
self.lasttag = tag = match.group(1).lower()
while k < endpos:
if self.strict:
m = attrfind.match(rawdata, k)
else:
m = attrfind_tolerant.match(rawdata, k)
if not m:
break
attrname, rest, attrvalue = m.group(1, 2, 3)
if not rest:
attrvalue = None
elif attrvalue[:1] == '\'' == attrvalue[-1:] or \
attrvalue[:1] == '"' == attrvalue[-1:]:
attrvalue = attrvalue[1:-1]
if attrvalue:
attrvalue = self.unescape(attrvalue)
attrs.append((attrname.lower(), attrvalue))
k = m.end()
end = rawdata[k:endpos].strip()
if end not in (">", "/>"):
lineno, offset = self.getpos()
if "\n" in self.__starttag_text:
lineno = lineno + self.__starttag_text.count("\n")
offset = len(self.__starttag_text) \
- self.__starttag_text.rfind("\n")
else:
offset = offset + len(self.__starttag_text)
if self.strict:
self.error("junk characters in start tag: %r"
% (rawdata[k:endpos][:20],))
self.handle_data(rawdata[i:endpos])
return endpos
if end.endswith('/>'):
# XHTML-style empty tag: <span attr="value" />
self.handle_startendtag(tag, attrs)
else:
self.handle_starttag(tag, attrs)
if tag in self.CDATA_CONTENT_ELEMENTS:
self.set_cdata_mode(tag)
return endpos
# Internal -- check to see if we have a complete starttag; return end
# or -1 if incomplete.
def check_for_whole_start_tag(self, i):
rawdata = self.rawdata
if self.strict:
m = locatestarttagend.match(rawdata, i)
else:
m = locatestarttagend_tolerant.match(rawdata, i)
if m:
j = m.end()
next = rawdata[j:j+1]
if next == ">":
return j + 1
if next == "/":
if rawdata.startswith("/>", j):
return j + 2
if rawdata.startswith("/", j):
# buffer boundary
return -1
# else bogus input
if self.strict:
self.updatepos(i, j + 1)
self.error("malformed empty start tag")
if j > i:
return j
else:
return i + 1
if next == "":
# end of input
return -1
if next in ("abcdefghijklmnopqrstuvwxyz=/"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"):
# end of input in or before attribute value, or we have the
# '/' from a '/>' ending
return -1
if self.strict:
self.updatepos(i, j)
self.error("malformed start tag")
if j > i:
return j
else:
return i + 1
raise AssertionError("we should not get here!")
# Internal -- parse endtag, return end or -1 if incomplete
def parse_endtag(self, i):
rawdata = self.rawdata
assert rawdata[i:i+2] == "</", "unexpected call to parse_endtag"
match = endendtag.search(rawdata, i+1) # >
if not match:
return -1
gtpos = match.end()
match = endtagfind.match(rawdata, i) # </ + tag + >
if not match:
if self.cdata_elem is not None:
self.handle_data(rawdata[i:gtpos])
return gtpos
if self.strict:
self.error("bad end tag: %r" % (rawdata[i:gtpos],))
# find the name: w3.org/TR/html5/tokenization.html#tag-name-state
namematch = tagfind_tolerant.match(rawdata, i+2)
if not namematch:
# w3.org/TR/html5/tokenization.html#end-tag-open-state
if rawdata[i:i+3] == '</>':
return i+3
else:
return self.parse_bogus_comment(i)
tagname = namematch.group().lower()
# consume and ignore other stuff between the name and the >
# Note: this is not 100% correct, since we might have things like
# </tag attr=">">, but looking for > after the name should cover
# most of the cases and is much simpler
gtpos = rawdata.find('>', namematch.end())
self.handle_endtag(tagname)
return gtpos+1
elem = match.group(1).lower() # script or style
if self.cdata_elem is not None:
if elem != self.cdata_elem:
self.handle_data(rawdata[i:gtpos])
return gtpos
self.handle_endtag(elem.lower())
self.clear_cdata_mode()
return gtpos
# Overridable -- finish processing of start+end tag: <tag.../>
def handle_startendtag(self, tag, attrs):
self.handle_starttag(tag, attrs)
self.handle_endtag(tag)
# Overridable -- handle start tag
def handle_starttag(self, tag, attrs):
pass
# Overridable -- handle end tag
def handle_endtag(self, tag):
pass
# Overridable -- handle character reference
def handle_charref(self, name):
pass
# Overridable -- handle entity reference
def handle_entityref(self, name):
pass
# Overridable -- handle data
def handle_data(self, data):
pass
# Overridable -- handle comment
def handle_comment(self, data):
pass
# Overridable -- handle declaration
def handle_decl(self, decl):
pass
# Overridable -- handle processing instruction
def handle_pi(self, data):
pass
def unknown_decl(self, data):
if self.strict:
self.error("unknown declaration: %r" % (data,))
# Internal -- helper to remove special character quoting
def unescape(self, s):
if '&' not in s:
return s
def replaceEntities(s):
s = s.groups()[0]
try:
if s[0] == "#":
s = s[1:]
if s[0] in ['x','X']:
c = int(s[1:].rstrip(';'), 16)
else:
c = int(s.rstrip(';'))
return chr(c)
except ValueError:
return '&#' + s
else:
from future.backports.html.entities import html5
if s in html5:
return html5[s]
elif s.endswith(';'):
return '&' + s
for x in range(2, len(s)):
if s[:x] in html5:
return html5[s[:x]] + s[x:]
else:
return '&' + s
return re.sub(r"&(#?[xX]?(?:[0-9a-fA-F]+;|\w{1,32};?))",
replaceEntities, s)
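# --- Hedged usage sketch (editor's addition, not upstream code). A minimal
# subclass following the usage pattern described in the HTMLParser docstring:
# override the handle_* hooks and drive the parser with feed()/close().
class _TitleGrabber(HTMLParser):
    """Collect the text inside the first <title> element."""
    def __init__(self):
        HTMLParser.__init__(self)
        self._in_title = False
        self.title = ''
    def handle_starttag(self, tag, attrs):
        if tag == 'title':
            self._in_title = True
    def handle_endtag(self, tag):
        if tag == 'title':
            self._in_title = False
    def handle_data(self, data):
        if self._in_title:
            self.title += data

if __name__ == '__main__':
    grabber = _TitleGrabber()
    grabber.feed('<html><head><title>Bazarr</title></head><body>hi</body></html>')
    grabber.close()
    print(grabber.title)   # Bazarr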

View file

File diff suppressed because it is too large

File diff suppressed because it is too large

View file

@@ -0,0 +1,597 @@
####
# Copyright 2000 by Timothy O'Malley <timo@alum.mit.edu>
#
# All Rights Reserved
#
# Permission to use, copy, modify, and distribute this software
# and its documentation for any purpose and without fee is hereby
# granted, provided that the above copyright notice appear in all
# copies and that both that copyright notice and this permission
# notice appear in supporting documentation, and that the name of
# Timothy O'Malley not be used in advertising or publicity
# pertaining to distribution of the software without specific, written
# prior permission.
#
# Timothy O'Malley DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS
# SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
# AND FITNESS, IN NO EVENT SHALL Timothy O'Malley BE LIABLE FOR
# ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
# PERFORMANCE OF THIS SOFTWARE.
#
####
#
# Id: Cookie.py,v 2.29 2000/08/23 05:28:49 timo Exp
# by Timothy O'Malley <timo@alum.mit.edu>
#
# Cookie.py is a Python module for the handling of HTTP
# cookies as a Python dictionary. See RFC 2109 for more
# information on cookies.
#
# The original idea to treat Cookies as a dictionary came from
# Dave Mitchell (davem@magnet.com) in 1995, when he released the
# first version of nscookie.py.
#
####
r"""
http.cookies module ported to python-future from Py3.3
Here's a sample session to show how to use this module.
At the moment, this is the only documentation.
The Basics
----------
Importing is easy...
>>> from http import cookies
Most of the time you start by creating a cookie.
>>> C = cookies.SimpleCookie()
Once you've created your Cookie, you can add values just as if it were
a dictionary.
>>> C = cookies.SimpleCookie()
>>> C["fig"] = "newton"
>>> C["sugar"] = "wafer"
>>> C.output()
'Set-Cookie: fig=newton\r\nSet-Cookie: sugar=wafer'
Notice that the printable representation of a Cookie is the
appropriate format for a Set-Cookie: header. This is the
default behavior. You can change the header and printed
attributes by using the .output() function
>>> C = cookies.SimpleCookie()
>>> C["rocky"] = "road"
>>> C["rocky"]["path"] = "/cookie"
>>> print(C.output(header="Cookie:"))
Cookie: rocky=road; Path=/cookie
>>> print(C.output(attrs=[], header="Cookie:"))
Cookie: rocky=road
The load() method of a Cookie extracts cookies from a string. In a
CGI script, you would use this method to extract the cookies from the
HTTP_COOKIE environment variable.
>>> C = cookies.SimpleCookie()
>>> C.load("chips=ahoy; vienna=finger")
>>> C.output()
'Set-Cookie: chips=ahoy\r\nSet-Cookie: vienna=finger'
The load() method is darn-tootin smart about identifying cookies
within a string. Escaped quotation marks, nested semicolons, and other
such trickeries do not confuse it.
>>> C = cookies.SimpleCookie()
>>> C.load('keebler="E=everybody; L=\\"Loves\\"; fudge=\\012;";')
>>> print(C)
Set-Cookie: keebler="E=everybody; L=\"Loves\"; fudge=\012;"
Each element of the Cookie also supports all of the RFC 2109
Cookie attributes. Here's an example which sets the Path
attribute.
>>> C = cookies.SimpleCookie()
>>> C["oreo"] = "doublestuff"
>>> C["oreo"]["path"] = "/"
>>> print(C)
Set-Cookie: oreo=doublestuff; Path=/
Each dictionary element has a 'value' attribute, which gives you
back the value associated with the key.
>>> C = cookies.SimpleCookie()
>>> C["twix"] = "none for you"
>>> C["twix"].value
'none for you'
The SimpleCookie expects that all values should be standard strings.
Just to be sure, SimpleCookie invokes the str() builtin to convert
the value to a string, when the values are set dictionary-style.
>>> C = cookies.SimpleCookie()
>>> C["number"] = 7
>>> C["string"] = "seven"
>>> C["number"].value
'7'
>>> C["string"].value
'seven'
>>> C.output()
'Set-Cookie: number=7\r\nSet-Cookie: string=seven'
Finis.
"""
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from future.builtins import chr, dict, int, str
from future.utils import PY2, as_native_str
#
# Import our required modules
#
import re
re.ASCII = 0 # for py2 compatibility
import string
__all__ = ["CookieError", "BaseCookie", "SimpleCookie"]
_nulljoin = ''.join
_semispacejoin = '; '.join
_spacejoin = ' '.join
#
# Define an exception visible to External modules
#
class CookieError(Exception):
pass
# These quoting routines conform to the RFC2109 specification, which in
# turn references the character definitions from RFC2068. They provide
# a two-way quoting algorithm. Any non-text character is translated
# into a 4 character sequence: a backslash followed by the
# three-digit octal equivalent of the character. Any '\' or '"' is
# quoted with a preceding '\'.
#
# These are taken from RFC2068 and RFC2109.
# _LegalChars is the list of chars which don't require "'s
# _Translator hash-table for fast quoting
#
_LegalChars = string.ascii_letters + string.digits + "!#$%&'*+-.^_`|~:"
_Translator = {
'\000' : '\\000', '\001' : '\\001', '\002' : '\\002',
'\003' : '\\003', '\004' : '\\004', '\005' : '\\005',
'\006' : '\\006', '\007' : '\\007', '\010' : '\\010',
'\011' : '\\011', '\012' : '\\012', '\013' : '\\013',
'\014' : '\\014', '\015' : '\\015', '\016' : '\\016',
'\017' : '\\017', '\020' : '\\020', '\021' : '\\021',
'\022' : '\\022', '\023' : '\\023', '\024' : '\\024',
'\025' : '\\025', '\026' : '\\026', '\027' : '\\027',
'\030' : '\\030', '\031' : '\\031', '\032' : '\\032',
'\033' : '\\033', '\034' : '\\034', '\035' : '\\035',
'\036' : '\\036', '\037' : '\\037',
# Because of the way browsers really handle cookies (as opposed
# to what the RFC says) we also encode , and ;
',' : '\\054', ';' : '\\073',
'"' : '\\"', '\\' : '\\\\',
'\177' : '\\177', '\200' : '\\200', '\201' : '\\201',
'\202' : '\\202', '\203' : '\\203', '\204' : '\\204',
'\205' : '\\205', '\206' : '\\206', '\207' : '\\207',
'\210' : '\\210', '\211' : '\\211', '\212' : '\\212',
'\213' : '\\213', '\214' : '\\214', '\215' : '\\215',
'\216' : '\\216', '\217' : '\\217', '\220' : '\\220',
'\221' : '\\221', '\222' : '\\222', '\223' : '\\223',
'\224' : '\\224', '\225' : '\\225', '\226' : '\\226',
'\227' : '\\227', '\230' : '\\230', '\231' : '\\231',
'\232' : '\\232', '\233' : '\\233', '\234' : '\\234',
'\235' : '\\235', '\236' : '\\236', '\237' : '\\237',
'\240' : '\\240', '\241' : '\\241', '\242' : '\\242',
'\243' : '\\243', '\244' : '\\244', '\245' : '\\245',
'\246' : '\\246', '\247' : '\\247', '\250' : '\\250',
'\251' : '\\251', '\252' : '\\252', '\253' : '\\253',
'\254' : '\\254', '\255' : '\\255', '\256' : '\\256',
'\257' : '\\257', '\260' : '\\260', '\261' : '\\261',
'\262' : '\\262', '\263' : '\\263', '\264' : '\\264',
'\265' : '\\265', '\266' : '\\266', '\267' : '\\267',
'\270' : '\\270', '\271' : '\\271', '\272' : '\\272',
'\273' : '\\273', '\274' : '\\274', '\275' : '\\275',
'\276' : '\\276', '\277' : '\\277', '\300' : '\\300',
'\301' : '\\301', '\302' : '\\302', '\303' : '\\303',
'\304' : '\\304', '\305' : '\\305', '\306' : '\\306',
'\307' : '\\307', '\310' : '\\310', '\311' : '\\311',
'\312' : '\\312', '\313' : '\\313', '\314' : '\\314',
'\315' : '\\315', '\316' : '\\316', '\317' : '\\317',
'\320' : '\\320', '\321' : '\\321', '\322' : '\\322',
'\323' : '\\323', '\324' : '\\324', '\325' : '\\325',
'\326' : '\\326', '\327' : '\\327', '\330' : '\\330',
'\331' : '\\331', '\332' : '\\332', '\333' : '\\333',
'\334' : '\\334', '\335' : '\\335', '\336' : '\\336',
'\337' : '\\337', '\340' : '\\340', '\341' : '\\341',
'\342' : '\\342', '\343' : '\\343', '\344' : '\\344',
'\345' : '\\345', '\346' : '\\346', '\347' : '\\347',
'\350' : '\\350', '\351' : '\\351', '\352' : '\\352',
'\353' : '\\353', '\354' : '\\354', '\355' : '\\355',
'\356' : '\\356', '\357' : '\\357', '\360' : '\\360',
'\361' : '\\361', '\362' : '\\362', '\363' : '\\363',
'\364' : '\\364', '\365' : '\\365', '\366' : '\\366',
'\367' : '\\367', '\370' : '\\370', '\371' : '\\371',
'\372' : '\\372', '\373' : '\\373', '\374' : '\\374',
'\375' : '\\375', '\376' : '\\376', '\377' : '\\377'
}
def _quote(str, LegalChars=_LegalChars):
r"""Quote a string for use in a cookie header.
If the string does not need to be double-quoted, then just return the
string. Otherwise, surround the string in doublequotes and quote
(with a \) special characters.
"""
if all(c in LegalChars for c in str):
return str
else:
return '"' + _nulljoin(_Translator.get(s, s) for s in str) + '"'
_OctalPatt = re.compile(r"\\[0-3][0-7][0-7]")
_QuotePatt = re.compile(r"[\\].")
def _unquote(mystr):
# If there aren't any doublequotes,
# then there can't be any special characters. See RFC 2109.
if len(mystr) < 2:
return mystr
if mystr[0] != '"' or mystr[-1] != '"':
return mystr
# We have to assume that we must decode this string.
# Down to work.
# Remove the "s
mystr = mystr[1:-1]
# Check for special sequences. Examples:
# \012 --> \n
# \" --> "
#
i = 0
n = len(mystr)
res = []
while 0 <= i < n:
o_match = _OctalPatt.search(mystr, i)
q_match = _QuotePatt.search(mystr, i)
if not o_match and not q_match: # Neither matched
res.append(mystr[i:])
break
# else:
j = k = -1
if o_match:
j = o_match.start(0)
if q_match:
k = q_match.start(0)
if q_match and (not o_match or k < j): # QuotePatt matched
res.append(mystr[i:k])
res.append(mystr[k+1])
i = k + 2
else: # OctalPatt matched
res.append(mystr[i:j])
res.append(chr(int(mystr[j+1:j+4], 8)))
i = j + 4
return _nulljoin(res)
# The _getdate() routine is used to set the expiration time in the cookie's HTTP
# header. By default, _getdate() returns the current time in the appropriate
# "expires" format for a Set-Cookie header. The one optional argument is an
# offset from now, in seconds. For example, an offset of -3600 means "one hour
# ago". The offset may be a floating point number.
#
_weekdayname = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
_monthname = [None,
'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
def _getdate(future=0, weekdayname=_weekdayname, monthname=_monthname):
from time import gmtime, time
now = time()
year, month, day, hh, mm, ss, wd, y, z = gmtime(now + future)
return "%s, %02d %3s %4d %02d:%02d:%02d GMT" % \
(weekdayname[wd], day, monthname[month], year, hh, mm, ss)
class Morsel(dict):
"""A class to hold ONE (key, value) pair.
In a cookie, each such pair may have several attributes, so this class is
used to keep the attributes associated with the appropriate key,value pair.
This class also includes a coded_value attribute, which is used to hold
the network representation of the value. This is most useful when Python
objects are pickled for network transit.
"""
# RFC 2109 lists these attributes as reserved:
# path comment domain
# max-age secure version
#
# For historical reasons, these attributes are also reserved:
# expires
#
# This is an extension from Microsoft:
# httponly
#
# This dictionary provides a mapping from the lowercase
# variant on the left to the appropriate traditional
# formatting on the right.
_reserved = {
"expires" : "expires",
"path" : "Path",
"comment" : "Comment",
"domain" : "Domain",
"max-age" : "Max-Age",
"secure" : "secure",
"httponly" : "httponly",
"version" : "Version",
}
_flags = set(['secure', 'httponly'])
def __init__(self):
# Set defaults
self.key = self.value = self.coded_value = None
# Set default attributes
for key in self._reserved:
dict.__setitem__(self, key, "")
def __setitem__(self, K, V):
K = K.lower()
if not K in self._reserved:
raise CookieError("Invalid Attribute %s" % K)
dict.__setitem__(self, K, V)
def isReservedKey(self, K):
return K.lower() in self._reserved
def set(self, key, val, coded_val, LegalChars=_LegalChars):
# First we verify that the key isn't a reserved word
# Second we make sure it only contains legal characters
if key.lower() in self._reserved:
raise CookieError("Attempt to set a reserved key: %s" % key)
if any(c not in LegalChars for c in key):
raise CookieError("Illegal key value: %s" % key)
# It's a good key, so save it.
self.key = key
self.value = val
self.coded_value = coded_val
def output(self, attrs=None, header="Set-Cookie:"):
return "%s %s" % (header, self.OutputString(attrs))
__str__ = output
@as_native_str()
def __repr__(self):
if PY2 and isinstance(self.value, unicode):
val = str(self.value) # make it a newstr to remove the u prefix
else:
val = self.value
return '<%s: %s=%s>' % (self.__class__.__name__,
str(self.key), repr(val))
def js_output(self, attrs=None):
# Print javascript
return """
<script type="text/javascript">
<!-- begin hiding
document.cookie = \"%s\";
// end hiding -->
</script>
""" % (self.OutputString(attrs).replace('"', r'\"'))
def OutputString(self, attrs=None):
# Build up our result
#
result = []
append = result.append
# First, the key=value pair
append("%s=%s" % (self.key, self.coded_value))
# Now add any defined attributes
if attrs is None:
attrs = self._reserved
items = sorted(self.items())
for key, value in items:
if value == "":
continue
if key not in attrs:
continue
if key == "expires" and isinstance(value, int):
append("%s=%s" % (self._reserved[key], _getdate(value)))
elif key == "max-age" and isinstance(value, int):
append("%s=%d" % (self._reserved[key], value))
elif key == "secure":
append(str(self._reserved[key]))
elif key == "httponly":
append(str(self._reserved[key]))
else:
append("%s=%s" % (self._reserved[key], value))
# Return the result
return _semispacejoin(result)
#
# Pattern for finding cookie
#
# This used to be strict parsing based on the RFC2109 and RFC2068
# specifications. I have since discovered that MSIE 3.0x doesn't
# follow the character rules outlined in those specs. As a
# result, the parsing rules here are less strict.
#
_LegalCharsPatt = r"[\w\d!#%&'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\=]"
_CookiePattern = re.compile(r"""
(?x) # This is a verbose pattern
(?P<key> # Start of group 'key'
""" + _LegalCharsPatt + r"""+? # Any word of at least one letter
) # End of group 'key'
( # Optional group: there may not be a value.
\s*=\s* # Equal Sign
(?P<val> # Start of group 'val'
"(?:[^\\"]|\\.)*" # Any doublequoted string
| # or
\w{3},\s[\w\d\s-]{9,11}\s[\d:]{8}\sGMT # Special case for "expires" attr
| # or
""" + _LegalCharsPatt + r"""* # Any word or empty string
) # End of group 'val'
)? # End of optional value group
\s* # Any number of spaces.
(\s+|;|$) # Ending either at space, semicolon, or EOS.
""", re.ASCII) # May be removed if safe.
# At long last, here is the cookie class. Using this class is almost just like
# using a dictionary. See this module's docstring for example usage.
#
class BaseCookie(dict):
"""A container class for a set of Morsels."""
def value_decode(self, val):
"""real_value, coded_value = value_decode(STRING)
Called prior to setting a cookie's value from the network
representation. The VALUE is the value read from HTTP
header.
Override this function to modify the behavior of cookies.
"""
return val, val
def value_encode(self, val):
"""real_value, coded_value = value_encode(VALUE)
Called prior to setting a cookie's value from the dictionary
representation. The VALUE is the value being assigned.
Override this function to modify the behavior of cookies.
"""
strval = str(val)
return strval, strval
def __init__(self, input=None):
if input:
self.load(input)
def __set(self, key, real_value, coded_value):
"""Private method for setting a cookie's value"""
M = self.get(key, Morsel())
M.set(key, real_value, coded_value)
dict.__setitem__(self, key, M)
def __setitem__(self, key, value):
"""Dictionary style assignment."""
rval, cval = self.value_encode(value)
self.__set(key, rval, cval)
def output(self, attrs=None, header="Set-Cookie:", sep="\015\012"):
"""Return a string suitable for HTTP."""
result = []
items = sorted(self.items())
for key, value in items:
result.append(value.output(attrs, header))
return sep.join(result)
__str__ = output
@as_native_str()
def __repr__(self):
l = []
items = sorted(self.items())
for key, value in items:
if PY2 and isinstance(value.value, unicode):
val = str(value.value) # make it a newstr to remove the u prefix
else:
val = value.value
l.append('%s=%s' % (str(key), repr(val)))
return '<%s: %s>' % (self.__class__.__name__, _spacejoin(l))
def js_output(self, attrs=None):
"""Return a string suitable for JavaScript."""
result = []
items = sorted(self.items())
for key, value in items:
result.append(value.js_output(attrs))
return _nulljoin(result)
def load(self, rawdata):
"""Load cookies from a string (presumably HTTP_COOKIE) or
from a dictionary. Loading cookies from a dictionary 'd'
is equivalent to calling:
map(Cookie.__setitem__, d.keys(), d.values())
"""
if isinstance(rawdata, str):
self.__parse_string(rawdata)
else:
# self.update() wouldn't call our custom __setitem__
for key, value in rawdata.items():
self[key] = value
return
def __parse_string(self, mystr, patt=_CookiePattern):
i = 0 # Our starting point
n = len(mystr) # Length of string
M = None # current morsel
while 0 <= i < n:
# Start looking for a cookie
match = patt.search(mystr, i)
if not match:
# No more cookies
break
key, value = match.group("key"), match.group("val")
i = match.end(0)
# Parse the key, value in case it's metainfo
if key[0] == "$":
# We ignore attributes which pertain to the cookie
# mechanism as a whole. See RFC 2109.
# (Does anyone care?)
if M:
M[key[1:]] = value
elif key.lower() in Morsel._reserved:
if M:
if value is None:
if key.lower() in Morsel._flags:
M[key] = True
else:
M[key] = _unquote(value)
elif value is not None:
rval, cval = self.value_decode(value)
self.__set(key, rval, cval)
M = self[key]
class SimpleCookie(BaseCookie):
"""
SimpleCookie supports strings as cookie values. When setting
the value using the dictionary assignment notation, SimpleCookie
calls the builtin str() to convert the value to a string. Values
received from HTTP are kept as strings.
"""
def value_decode(self, val):
return _unquote(val), val
def value_encode(self, val):
strval = str(val)
return strval, _quote(strval)
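# Illustrative usage sketch (not part of the vendored module): parse an
# incoming Cookie header, add a new morsel with reserved attributes, and emit
# the corresponding Set-Cookie headers.
def _simplecookie_example():
    jar = SimpleCookie()
    jar.load("chips=ahoy; vienna=finger")     # parse a Cookie header string
    assert jar["chips"].value == "ahoy"
    jar["rocky"] = "road"                     # dict-style assignment creates a Morsel
    jar["rocky"]["path"] = "/cookie"          # reserved attributes are set on the Morsel
    jar["rocky"]["max-age"] = 3600
    return jar.output()                       # one "Set-Cookie: ..." line per morsel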

File diff suppressed because it is too large

View file

@@ -0,0 +1,940 @@
"""
Miscellaneous function (re)definitions from the Py3.4+ standard library
for Python 2.6/2.7.
- math.ceil (for Python 2.7)
- collections.OrderedDict (for Python 2.6)
- collections.Counter (for Python 2.6)
- collections.ChainMap (for all versions prior to Python 3.3)
- itertools.count (for Python 2.6, with step parameter)
- subprocess.check_output (for Python 2.6)
- reprlib.recursive_repr (for Python 2.6+)
- functools.cmp_to_key (for Python 2.6)
"""
from __future__ import absolute_import
import subprocess
from math import ceil as oldceil
from collections import Mapping, MutableMapping
from operator import itemgetter as _itemgetter, eq as _eq
import sys
import heapq as _heapq
from _weakref import proxy as _proxy
from itertools import repeat as _repeat, chain as _chain, starmap as _starmap
from socket import getaddrinfo, SOCK_STREAM, error, socket
from future.utils import iteritems, itervalues, PY26, PY3
def ceil(x):
"""
Return the ceiling of x as an int.
This is the smallest integral value >= x.
"""
return int(oldceil(x))
########################################################################
### reprlib.recursive_repr decorator from Py3.4
########################################################################
from itertools import islice
if PY3:
try:
from _thread import get_ident
except ImportError:
from _dummy_thread import get_ident
else:
try:
from thread import get_ident
except ImportError:
from dummy_thread import get_ident
def recursive_repr(fillvalue='...'):
'Decorator to make a repr function return fillvalue for a recursive call'
def decorating_function(user_function):
repr_running = set()
def wrapper(self):
key = id(self), get_ident()
if key in repr_running:
return fillvalue
repr_running.add(key)
try:
result = user_function(self)
finally:
repr_running.discard(key)
return result
# Can't use functools.wraps() here because of bootstrap issues
wrapper.__module__ = getattr(user_function, '__module__')
wrapper.__doc__ = getattr(user_function, '__doc__')
wrapper.__name__ = getattr(user_function, '__name__')
wrapper.__annotations__ = getattr(user_function, '__annotations__', {})
return wrapper
return decorating_function
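# Illustrative usage sketch (not part of the vendored module): recursive_repr
# keeps __repr__ from recursing forever on a self-referential container.
class _ReprExample(object):
    def __init__(self):
        self.items = []
    @recursive_repr()
    def __repr__(self):
        return '_ReprExample(%r)' % (self.items,)
# box = _ReprExample(); box.items.append(box); repr(box) == '_ReprExample([...])'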
################################################################################
### OrderedDict
################################################################################
class _Link(object):
__slots__ = 'prev', 'next', 'key', '__weakref__'
class OrderedDict(dict):
'Dictionary that remembers insertion order'
# An inherited dict maps keys to values.
# The inherited dict provides __getitem__, __len__, __contains__, and get.
# The remaining methods are order-aware.
# Big-O running times for all methods are the same as regular dictionaries.
# The internal self.__map dict maps keys to links in a doubly linked list.
# The circular doubly linked list starts and ends with a sentinel element.
# The sentinel element never gets deleted (this simplifies the algorithm).
# The sentinel is in self.__hardroot with a weakref proxy in self.__root.
# The prev links are weakref proxies (to prevent circular references).
# Individual links are kept alive by the hard reference in self.__map.
# Those hard references disappear when a key is deleted from an OrderedDict.
def __init__(*args, **kwds):
'''Initialize an ordered dictionary. The signature is the same as
regular dictionaries, but keyword arguments are not recommended because
their insertion order is arbitrary.
'''
if not args:
raise TypeError("descriptor '__init__' of 'OrderedDict' object "
"needs an argument")
self = args[0]
args = args[1:]
if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % len(args))
try:
self.__root
except AttributeError:
self.__hardroot = _Link()
self.__root = root = _proxy(self.__hardroot)
root.prev = root.next = root
self.__map = {}
self.__update(*args, **kwds)
def __setitem__(self, key, value,
dict_setitem=dict.__setitem__, proxy=_proxy, Link=_Link):
'od.__setitem__(i, y) <==> od[i]=y'
# Setting a new item creates a new link at the end of the linked list,
# and the inherited dictionary is updated with the new key/value pair.
if key not in self:
self.__map[key] = link = Link()
root = self.__root
last = root.prev
link.prev, link.next, link.key = last, root, key
last.next = link
root.prev = proxy(link)
dict_setitem(self, key, value)
def __delitem__(self, key, dict_delitem=dict.__delitem__):
'od.__delitem__(y) <==> del od[y]'
# Deleting an existing item uses self.__map to find the link which gets
# removed by updating the links in the predecessor and successor nodes.
dict_delitem(self, key)
link = self.__map.pop(key)
link_prev = link.prev
link_next = link.next
link_prev.next = link_next
link_next.prev = link_prev
def __iter__(self):
'od.__iter__() <==> iter(od)'
# Traverse the linked list in order.
root = self.__root
curr = root.next
while curr is not root:
yield curr.key
curr = curr.next
def __reversed__(self):
'od.__reversed__() <==> reversed(od)'
# Traverse the linked list in reverse order.
root = self.__root
curr = root.prev
while curr is not root:
yield curr.key
curr = curr.prev
def clear(self):
'od.clear() -> None. Remove all items from od.'
root = self.__root
root.prev = root.next = root
self.__map.clear()
dict.clear(self)
def popitem(self, last=True):
'''od.popitem() -> (k, v), return and remove a (key, value) pair.
Pairs are returned in LIFO order if last is true or FIFO order if false.
'''
if not self:
raise KeyError('dictionary is empty')
root = self.__root
if last:
link = root.prev
link_prev = link.prev
link_prev.next = root
root.prev = link_prev
else:
link = root.next
link_next = link.next
root.next = link_next
link_next.prev = root
key = link.key
del self.__map[key]
value = dict.pop(self, key)
return key, value
def move_to_end(self, key, last=True):
'''Move an existing element to the end (or beginning if last==False).
Raises KeyError if the element does not exist.
When last=True, acts like a fast version of self[key]=self.pop(key).
'''
link = self.__map[key]
link_prev = link.prev
link_next = link.next
link_prev.next = link_next
link_next.prev = link_prev
root = self.__root
if last:
last = root.prev
link.prev = last
link.next = root
last.next = root.prev = link
else:
first = root.next
link.prev = root
link.next = first
root.next = first.prev = link
def __sizeof__(self):
sizeof = sys.getsizeof
n = len(self) + 1 # number of links including root
size = sizeof(self.__dict__) # instance dictionary
size += sizeof(self.__map) * 2 # internal dict and inherited dict
size += sizeof(self.__hardroot) * n # link objects
size += sizeof(self.__root) * n # proxy objects
return size
update = __update = MutableMapping.update
keys = MutableMapping.keys
values = MutableMapping.values
items = MutableMapping.items
__ne__ = MutableMapping.__ne__
__marker = object()
def pop(self, key, default=__marker):
'''od.pop(k[,d]) -> v, remove specified key and return the corresponding
value. If key is not found, d is returned if given, otherwise KeyError
is raised.
'''
if key in self:
result = self[key]
del self[key]
return result
if default is self.__marker:
raise KeyError(key)
return default
def setdefault(self, key, default=None):
'od.setdefault(k[,d]) -> od.get(k,d), also set od[k]=d if k not in od'
if key in self:
return self[key]
self[key] = default
return default
@recursive_repr()
def __repr__(self):
'od.__repr__() <==> repr(od)'
if not self:
return '%s()' % (self.__class__.__name__,)
return '%s(%r)' % (self.__class__.__name__, list(self.items()))
def __reduce__(self):
'Return state information for pickling'
inst_dict = vars(self).copy()
for k in vars(OrderedDict()):
inst_dict.pop(k, None)
return self.__class__, (), inst_dict or None, None, iter(self.items())
def copy(self):
'od.copy() -> a shallow copy of od'
return self.__class__(self)
@classmethod
def fromkeys(cls, iterable, value=None):
'''OD.fromkeys(S[, v]) -> New ordered dictionary with keys from S.
If not specified, the value defaults to None.
'''
self = cls()
for key in iterable:
self[key] = value
return self
def __eq__(self, other):
'''od.__eq__(y) <==> od==y. Comparison to another OD is order-sensitive
while comparison to a regular mapping is order-insensitive.
'''
if isinstance(other, OrderedDict):
return dict.__eq__(self, other) and all(map(_eq, self, other))
return dict.__eq__(self, other)
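# Illustrative usage sketch (not part of the vendored module): the backport
# preserves insertion order and supports move_to_end() and FIFO popitem().
def _ordereddict_example():
    d = OrderedDict([('a', 1), ('b', 2), ('c', 3)])
    d.move_to_end('a')                        # order becomes b, c, a
    assert list(d) == ['b', 'c', 'a']
    assert d.popitem(last=False) == ('b', 2)  # pop from the front when last=False
    return d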
# {{{ http://code.activestate.com/recipes/576611/ (r11)
try:
from operator import itemgetter
from heapq import nlargest
except ImportError:
pass
########################################################################
### Counter
########################################################################
def _count_elements(mapping, iterable):
'Tally elements from the iterable.'
mapping_get = mapping.get
for elem in iterable:
mapping[elem] = mapping_get(elem, 0) + 1
class Counter(dict):
'''Dict subclass for counting hashable items. Sometimes called a bag
or multiset. Elements are stored as dictionary keys and their counts
are stored as dictionary values.
>>> c = Counter('abcdeabcdabcaba') # count elements from a string
>>> c.most_common(3) # three most common elements
[('a', 5), ('b', 4), ('c', 3)]
>>> sorted(c) # list all unique elements
['a', 'b', 'c', 'd', 'e']
>>> ''.join(sorted(c.elements())) # list elements with repetitions
'aaaaabbbbcccdde'
>>> sum(c.values()) # total of all counts
15
>>> c['a'] # count of letter 'a'
5
>>> for elem in 'shazam': # update counts from an iterable
... c[elem] += 1 # by adding 1 to each element's count
>>> c['a'] # now there are seven 'a'
7
>>> del c['b'] # remove all 'b'
>>> c['b'] # now there are zero 'b'
0
>>> d = Counter('simsalabim') # make another counter
>>> c.update(d) # add in the second counter
>>> c['a'] # now there are nine 'a'
9
>>> c.clear() # empty the counter
>>> c
Counter()
Note: If a count is set to zero or reduced to zero, it will remain
in the counter until the entry is deleted or the counter is cleared:
>>> c = Counter('aaabbc')
>>> c['b'] -= 2 # reduce the count of 'b' by two
>>> c.most_common() # 'b' is still in, but its count is zero
[('a', 3), ('c', 1), ('b', 0)]
'''
# References:
# http://en.wikipedia.org/wiki/Multiset
# http://www.gnu.org/software/smalltalk/manual-base/html_node/Bag.html
# http://www.demo2s.com/Tutorial/Cpp/0380__set-multiset/Catalog0380__set-multiset.htm
# http://code.activestate.com/recipes/259174/
# Knuth, TAOCP Vol. II section 4.6.3
def __init__(*args, **kwds):
'''Create a new, empty Counter object. And if given, count elements
from an input iterable. Or, initialize the count from another mapping
of elements to their counts.
>>> c = Counter() # a new, empty counter
>>> c = Counter('gallahad') # a new counter from an iterable
>>> c = Counter({'a': 4, 'b': 2}) # a new counter from a mapping
>>> c = Counter(a=4, b=2) # a new counter from keyword args
'''
if not args:
raise TypeError("descriptor '__init__' of 'Counter' object "
"needs an argument")
self = args[0]
args = args[1:]
if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % len(args))
super(Counter, self).__init__()
self.update(*args, **kwds)
def __missing__(self, key):
'The count of elements not in the Counter is zero.'
# Needed so that self[missing_item] does not raise KeyError
return 0
def most_common(self, n=None):
'''List the n most common elements and their counts from the most
common to the least. If n is None, then list all element counts.
>>> Counter('abcdeabcdabcaba').most_common(3)
[('a', 5), ('b', 4), ('c', 3)]
'''
# Emulate Bag.sortedByCount from Smalltalk
if n is None:
return sorted(self.items(), key=_itemgetter(1), reverse=True)
return _heapq.nlargest(n, self.items(), key=_itemgetter(1))
def elements(self):
'''Iterator over elements repeating each as many times as its count.
>>> c = Counter('ABCABC')
>>> sorted(c.elements())
['A', 'A', 'B', 'B', 'C', 'C']
# Knuth's example for prime factors of 1836: 2**2 * 3**3 * 17**1
>>> prime_factors = Counter({2: 2, 3: 3, 17: 1})
>>> product = 1
>>> for factor in prime_factors.elements(): # loop over factors
... product *= factor # and multiply them
>>> product
1836
Note, if an element's count has been set to zero or is a negative
number, elements() will ignore it.
'''
# Emulate Bag.do from Smalltalk and Multiset.begin from C++.
return _chain.from_iterable(_starmap(_repeat, self.items()))
# Override dict methods where necessary
@classmethod
def fromkeys(cls, iterable, v=None):
# There is no equivalent method for counters because setting v=1
# means that no element can have a count greater than one.
raise NotImplementedError(
'Counter.fromkeys() is undefined. Use Counter(iterable) instead.')
def update(*args, **kwds):
'''Like dict.update() but add counts instead of replacing them.
Source can be an iterable, a dictionary, or another Counter instance.
>>> c = Counter('which')
>>> c.update('witch') # add elements from another iterable
>>> d = Counter('watch')
>>> c.update(d) # add elements from another counter
>>> c['h'] # four 'h' in which, witch, and watch
4
'''
# The regular dict.update() operation makes no sense here because the
# replace behavior results in some of the original untouched counts
# being mixed in with all of the other counts for a mishmash that
# doesn't have a straightforward interpretation in most counting
# contexts. Instead, we implement straight-addition. Both the inputs
# and outputs are allowed to contain zero and negative counts.
if not args:
raise TypeError("descriptor 'update' of 'Counter' object "
"needs an argument")
self = args[0]
args = args[1:]
if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % len(args))
iterable = args[0] if args else None
if iterable is not None:
if isinstance(iterable, Mapping):
if self:
self_get = self.get
for elem, count in iterable.items():
self[elem] = count + self_get(elem, 0)
else:
super(Counter, self).update(iterable) # fast path when counter is empty
else:
_count_elements(self, iterable)
if kwds:
self.update(kwds)
def subtract(*args, **kwds):
'''Like dict.update() but subtracts counts instead of replacing them.
Counts can be reduced below zero. Both the inputs and outputs are
allowed to contain zero and negative counts.
Source can be an iterable, a dictionary, or another Counter instance.
>>> c = Counter('which')
>>> c.subtract('witch') # subtract elements from another iterable
>>> c.subtract(Counter('watch')) # subtract elements from another counter
>>> c['h'] # 2 in which, minus 1 in witch, minus 1 in watch
0
>>> c['w'] # 1 in which, minus 1 in witch, minus 1 in watch
-1
'''
if not args:
raise TypeError("descriptor 'subtract' of 'Counter' object "
"needs an argument")
self = args[0]
args = args[1:]
if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % len(args))
iterable = args[0] if args else None
if iterable is not None:
self_get = self.get
if isinstance(iterable, Mapping):
for elem, count in iterable.items():
self[elem] = self_get(elem, 0) - count
else:
for elem in iterable:
self[elem] = self_get(elem, 0) - 1
if kwds:
self.subtract(kwds)
def copy(self):
'Return a shallow copy.'
return self.__class__(self)
def __reduce__(self):
return self.__class__, (dict(self),)
def __delitem__(self, elem):
'Like dict.__delitem__() but does not raise KeyError for missing values.'
if elem in self:
super(Counter, self).__delitem__(elem)
def __repr__(self):
if not self:
return '%s()' % self.__class__.__name__
try:
items = ', '.join(map('%r: %r'.__mod__, self.most_common()))
return '%s({%s})' % (self.__class__.__name__, items)
except TypeError:
# handle case where values are not orderable
return '{0}({1!r})'.format(self.__class__.__name__, dict(self))
# Multiset-style mathematical operations discussed in:
# Knuth TAOCP Volume II section 4.6.3 exercise 19
# and at http://en.wikipedia.org/wiki/Multiset
#
# Outputs guaranteed to only include positive counts.
#
# To strip negative and zero counts, add-in an empty counter:
# c += Counter()
def __add__(self, other):
'''Add counts from two counters.
>>> Counter('abbb') + Counter('bcc')
Counter({'b': 4, 'c': 2, 'a': 1})
'''
if not isinstance(other, Counter):
return NotImplemented
result = Counter()
for elem, count in self.items():
newcount = count + other[elem]
if newcount > 0:
result[elem] = newcount
for elem, count in other.items():
if elem not in self and count > 0:
result[elem] = count
return result
def __sub__(self, other):
''' Subtract count, but keep only results with positive counts.
>>> Counter('abbbc') - Counter('bccd')
Counter({'b': 2, 'a': 1})
'''
if not isinstance(other, Counter):
return NotImplemented
result = Counter()
for elem, count in self.items():
newcount = count - other[elem]
if newcount > 0:
result[elem] = newcount
for elem, count in other.items():
if elem not in self and count < 0:
result[elem] = 0 - count
return result
def __or__(self, other):
'''Union is the maximum of value in either of the input counters.
>>> Counter('abbb') | Counter('bcc')
Counter({'b': 3, 'c': 2, 'a': 1})
'''
if not isinstance(other, Counter):
return NotImplemented
result = Counter()
for elem, count in self.items():
other_count = other[elem]
newcount = other_count if count < other_count else count
if newcount > 0:
result[elem] = newcount
for elem, count in other.items():
if elem not in self and count > 0:
result[elem] = count
return result
def __and__(self, other):
''' Intersection is the minimum of corresponding counts.
>>> Counter('abbb') & Counter('bcc')
Counter({'b': 1})
'''
if not isinstance(other, Counter):
return NotImplemented
result = Counter()
for elem, count in self.items():
other_count = other[elem]
newcount = count if count < other_count else other_count
if newcount > 0:
result[elem] = newcount
return result
def __pos__(self):
'Adds an empty counter, effectively stripping negative and zero counts'
return self + Counter()
def __neg__(self):
'''Subtracts from an empty counter. Strips positive and zero counts,
and flips the sign on negative counts.
'''
return Counter() - self
def _keep_positive(self):
'''Internal method to strip elements with a negative or zero count'''
nonpositive = [elem for elem, count in self.items() if not count > 0]
for elem in nonpositive:
del self[elem]
return self
def __iadd__(self, other):
'''Inplace add from another counter, keeping only positive counts.
>>> c = Counter('abbb')
>>> c += Counter('bcc')
>>> c
Counter({'b': 4, 'c': 2, 'a': 1})
'''
for elem, count in other.items():
self[elem] += count
return self._keep_positive()
def __isub__(self, other):
'''Inplace subtract counter, but keep only results with positive counts.
>>> c = Counter('abbbc')
>>> c -= Counter('bccd')
>>> c
Counter({'b': 2, 'a': 1})
'''
for elem, count in other.items():
self[elem] -= count
return self._keep_positive()
def __ior__(self, other):
'''Inplace union is the maximum of value from either counter.
>>> c = Counter('abbb')
>>> c |= Counter('bcc')
>>> c
Counter({'b': 3, 'c': 2, 'a': 1})
'''
for elem, other_count in other.items():
count = self[elem]
if other_count > count:
self[elem] = other_count
return self._keep_positive()
def __iand__(self, other):
'''Inplace intersection is the minimum of corresponding counts.
>>> c = Counter('abbb')
>>> c &= Counter('bcc')
>>> c
Counter({'b': 1})
'''
for elem, count in self.items():
other_count = other[elem]
if other_count < count:
self[elem] = other_count
return self._keep_positive()
def check_output(*popenargs, **kwargs):
"""
For Python 2.6 compatibility: see
http://stackoverflow.com/questions/4814970/
"""
if 'stdout' in kwargs:
raise ValueError('stdout argument not allowed, it will be overridden.')
process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
output, unused_err = process.communicate()
retcode = process.poll()
if retcode:
cmd = kwargs.get("args")
if cmd is None:
cmd = popenargs[0]
raise subprocess.CalledProcessError(retcode, cmd)
return output
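# Illustrative usage sketch (not part of the vendored module); 'echo' is an
# assumed POSIX command used purely for demonstration.
def _check_output_example():
    out = check_output(['echo', 'hello'])    # raises CalledProcessError on non-zero exit
    return out                               # b'hello\n' (bytes under Python 3)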
def count(start=0, step=1):
"""
``itertools.count`` in Py 2.6 doesn't accept a step
parameter. This is an enhanced version of ``itertools.count``
for Py2.6 equivalent to ``itertools.count`` in Python 2.7+.
"""
while True:
yield start
start += step
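# Illustrative usage sketch (not part of the vendored module): the step-aware
# count(), consumed lazily with itertools.islice (imported above).
def _count_example():
    return list(islice(count(10, 2), 5))     # [10, 12, 14, 16, 18]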
########################################################################
### ChainMap (helper for configparser and string.Template)
### From the Py3.4 source code. See also:
### https://github.com/kkxue/Py2ChainMap/blob/master/py2chainmap.py
########################################################################
class ChainMap(MutableMapping):
''' A ChainMap groups multiple dicts (or other mappings) together
to create a single, updateable view.
The underlying mappings are stored in a list. That list is public and can
be accessed or updated using the *maps* attribute. There is no other state.
Lookups search the underlying mappings successively until a key is found.
In contrast, writes, updates, and deletions only operate on the first
mapping.
'''
def __init__(self, *maps):
'''Initialize a ChainMap by setting *maps* to the given mappings.
If no mappings are provided, a single empty dictionary is used.
'''
self.maps = list(maps) or [{}] # always at least one map
def __missing__(self, key):
raise KeyError(key)
def __getitem__(self, key):
for mapping in self.maps:
try:
return mapping[key] # can't use 'key in mapping' with defaultdict
except KeyError:
pass
return self.__missing__(key) # support subclasses that define __missing__
def get(self, key, default=None):
return self[key] if key in self else default
def __len__(self):
return len(set().union(*self.maps)) # reuses stored hash values if possible
def __iter__(self):
return iter(set().union(*self.maps))
def __contains__(self, key):
return any(key in m for m in self.maps)
def __bool__(self):
return any(self.maps)
# Py2 compatibility:
__nonzero__ = __bool__
@recursive_repr()
def __repr__(self):
return '{0.__class__.__name__}({1})'.format(
self, ', '.join(map(repr, self.maps)))
@classmethod
def fromkeys(cls, iterable, *args):
'Create a ChainMap with a single dict created from the iterable.'
return cls(dict.fromkeys(iterable, *args))
def copy(self):
'New ChainMap or subclass with a new copy of maps[0] and refs to maps[1:]'
return self.__class__(self.maps[0].copy(), *self.maps[1:])
__copy__ = copy
def new_child(self, m=None): # like Django's Context.push()
'''
New ChainMap with a new map followed by all previous maps. If no
map is provided, an empty dict is used.
'''
if m is None:
m = {}
return self.__class__(m, *self.maps)
@property
def parents(self): # like Django's Context.pop()
'New ChainMap from maps[1:].'
return self.__class__(*self.maps[1:])
def __setitem__(self, key, value):
self.maps[0][key] = value
def __delitem__(self, key):
try:
del self.maps[0][key]
except KeyError:
raise KeyError('Key not found in the first mapping: {0!r}'.format(key))
def popitem(self):
'Remove and return an item pair from maps[0]. Raise KeyError if maps[0] is empty.'
try:
return self.maps[0].popitem()
except KeyError:
raise KeyError('No keys found in the first mapping.')
def pop(self, key, *args):
'Remove *key* from maps[0] and return its value. Raise KeyError if *key* not in maps[0].'
try:
return self.maps[0].pop(key, *args)
except KeyError:
raise KeyError('Key not found in the first mapping: {0!r}'.format(key))
def clear(self):
'Clear maps[0], leaving maps[1:] intact.'
self.maps[0].clear()
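# Illustrative usage sketch (not part of the vendored module): lookups fall
# through the chain of mappings, while writes only touch the first one.
def _chainmap_example():
    defaults = {'color': 'red', 'user': 'guest'}
    overrides = {'user': 'admin'}
    cm = ChainMap(overrides, defaults)
    assert cm['user'] == 'admin'             # found in the first mapping
    assert cm['color'] == 'red'              # falls through to the defaults
    cm['color'] = 'blue'                     # written to overrides only
    assert defaults['color'] == 'red'
    return cm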
# Re-use the same sentinel as in the Python stdlib socket module:
from socket import _GLOBAL_DEFAULT_TIMEOUT
# Was: _GLOBAL_DEFAULT_TIMEOUT = object()
def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT,
source_address=None):
"""Backport of 3-argument create_connection() for Py2.6.
Connect to *address* and return the socket object.
Convenience function. Connect to *address* (a 2-tuple ``(host,
port)``) and return the socket object. Passing the optional
*timeout* parameter will set the timeout on the socket instance
before attempting to connect. If no *timeout* is supplied, the
global default timeout setting returned by :func:`getdefaulttimeout`
is used. If *source_address* is set it must be a tuple of (host, port)
for the socket to bind as a source address before making the connection.
A host of '' or port 0 tells the OS to use the default.
"""
host, port = address
err = None
for res in getaddrinfo(host, port, 0, SOCK_STREAM):
af, socktype, proto, canonname, sa = res
sock = None
try:
sock = socket(af, socktype, proto)
if timeout is not _GLOBAL_DEFAULT_TIMEOUT:
sock.settimeout(timeout)
if source_address:
sock.bind(source_address)
sock.connect(sa)
return sock
except error as _:
err = _
if sock is not None:
sock.close()
if err is not None:
raise err
else:
raise error("getaddrinfo returns an empty list")
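# Illustrative usage sketch (not part of the vendored module): open a TCP
# connection with a timeout; the host and port below are placeholder values.
def _create_connection_example(host='example.com', port=80):
    sock = create_connection((host, port), timeout=5)
    try:
        return sock.getpeername()
    finally:
        sock.close()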
# Backport from Py2.7 for Py2.6:
def cmp_to_key(mycmp):
"""Convert a cmp= function into a key= function"""
class K(object):
__slots__ = ['obj']
def __init__(self, obj, *args):
self.obj = obj
def __lt__(self, other):
return mycmp(self.obj, other.obj) < 0
def __gt__(self, other):
return mycmp(self.obj, other.obj) > 0
def __eq__(self, other):
return mycmp(self.obj, other.obj) == 0
def __le__(self, other):
return mycmp(self.obj, other.obj) <= 0
def __ge__(self, other):
return mycmp(self.obj, other.obj) >= 0
def __ne__(self, other):
return mycmp(self.obj, other.obj) != 0
def __hash__(self):
raise TypeError('hash not implemented')
return K
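# Illustrative usage sketch (not part of the vendored module): adapt an
# old-style three-way cmp function for use as a sort key.
def _cmp_to_key_example():
    def numeric_cmp(a, b):
        return (a > b) - (a < b)             # classic -1/0/1 comparison
    return sorted([3, 1, 2], key=cmp_to_key(numeric_cmp))   # [1, 2, 3]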
# Back up our definitions above in case they're useful
_OrderedDict = OrderedDict
_Counter = Counter
_check_output = check_output
_count = count
_ceil = ceil
__count_elements = _count_elements
_recursive_repr = recursive_repr
_ChainMap = ChainMap
_create_connection = create_connection
_cmp_to_key = cmp_to_key
# Overwrite the definitions above with the usual ones
# from the standard library:
if sys.version_info >= (2, 7):
from collections import OrderedDict, Counter
from itertools import count
from functools import cmp_to_key
try:
from subprocess import check_output
except ImportError:
# Not available. This happens with Google App Engine: see issue #231
pass
from socket import create_connection
if sys.version_info >= (3, 0):
from math import ceil
from collections import _count_elements
if sys.version_info >= (3, 3):
from reprlib import recursive_repr
from collections import ChainMap

View file

@@ -0,0 +1,454 @@
# Wrapper module for _socket, providing some additional facilities
# implemented in Python.
"""\
This module provides socket operations and some related functions.
On Unix, it supports IP (Internet Protocol) and Unix domain sockets.
On other systems, it only supports IP. Functions specific for a
socket are available as methods of the socket object.
Functions:
socket() -- create a new socket object
socketpair() -- create a pair of new socket objects [*]
fromfd() -- create a socket object from an open file descriptor [*]
fromshare() -- create a socket object from data received from socket.share() [*]
gethostname() -- return the current hostname
gethostbyname() -- map a hostname to its IP number
gethostbyaddr() -- map an IP number or hostname to DNS info
getservbyname() -- map a service name and a protocol name to a port number
getprotobyname() -- map a protocol name (e.g. 'tcp') to a number
ntohs(), ntohl() -- convert 16, 32 bit int from network to host byte order
htons(), htonl() -- convert 16, 32 bit int from host to network byte order
inet_aton() -- convert IP addr string (123.45.67.89) to 32-bit packed format
inet_ntoa() -- convert 32-bit packed format IP to string (123.45.67.89)
socket.getdefaulttimeout() -- get the default timeout value
socket.setdefaulttimeout() -- set the default timeout value
create_connection() -- connects to an address, with an optional timeout and
optional source address.
[*] not available on all platforms!
Special objects:
SocketType -- type object for socket objects
error -- exception raised for I/O errors
has_ipv6 -- boolean value indicating if IPv6 is supported
Integer constants:
AF_INET, AF_UNIX -- socket domains (first argument to socket() call)
SOCK_STREAM, SOCK_DGRAM, SOCK_RAW -- socket types (second argument)
Many other constants may be defined; these may be used in calls to
the setsockopt() and getsockopt() methods.
"""
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from future.builtins import super
import _socket
from _socket import *
import os, sys, io
try:
import errno
except ImportError:
errno = None
EBADF = getattr(errno, 'EBADF', 9)
EAGAIN = getattr(errno, 'EAGAIN', 11)
EWOULDBLOCK = getattr(errno, 'EWOULDBLOCK', 11)
__all__ = ["getfqdn", "create_connection"]
__all__.extend(os._get_exports_list(_socket))
_realsocket = socket
# WSA error codes
if sys.platform.lower().startswith("win"):
errorTab = {}
errorTab[10004] = "The operation was interrupted."
errorTab[10009] = "A bad file handle was passed."
errorTab[10013] = "Permission denied."
errorTab[10014] = "A fault occurred on the network??" # WSAEFAULT
errorTab[10022] = "An invalid operation was attempted."
errorTab[10035] = "The socket operation would block"
errorTab[10036] = "A blocking operation is already in progress."
errorTab[10048] = "The network address is in use."
errorTab[10054] = "The connection has been reset."
errorTab[10058] = "The network has been shut down."
errorTab[10060] = "The operation timed out."
errorTab[10061] = "Connection refused."
errorTab[10063] = "The name is too long."
errorTab[10064] = "The host is down."
errorTab[10065] = "The host is unreachable."
__all__.append("errorTab")
class socket(_socket.socket):
"""A subclass of _socket.socket adding the makefile() method."""
__slots__ = ["__weakref__", "_io_refs", "_closed"]
def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None):
if fileno is None:
_socket.socket.__init__(self, family, type, proto)
else:
_socket.socket.__init__(self, family, type, proto, fileno)
self._io_refs = 0
self._closed = False
def __enter__(self):
return self
def __exit__(self, *args):
if not self._closed:
self.close()
def __repr__(self):
"""Wrap __repr__() to reveal the real class name."""
s = _socket.socket.__repr__(self)
if s.startswith("<socket object"):
s = "<%s.%s%s%s" % (self.__class__.__module__,
self.__class__.__name__,
getattr(self, '_closed', False) and " [closed] " or "",
s[7:])
return s
def __getstate__(self):
raise TypeError("Cannot serialize socket object")
def dup(self):
"""dup() -> socket object
Return a new socket object connected to the same system resource.
"""
fd = dup(self.fileno())
sock = self.__class__(self.family, self.type, self.proto, fileno=fd)
sock.settimeout(self.gettimeout())
return sock
def accept(self):
"""accept() -> (socket object, address info)
Wait for an incoming connection. Return a new socket
representing the connection, and the address of the client.
For IP sockets, the address info is a pair (hostaddr, port).
"""
fd, addr = self._accept()
sock = socket(self.family, self.type, self.proto, fileno=fd)
# Issue #7995: if no default timeout is set and the listening
# socket had a (non-zero) timeout, force the new socket in blocking
# mode to override platform-specific socket flags inheritance.
if getdefaulttimeout() is None and self.gettimeout():
sock.setblocking(True)
return sock, addr
def makefile(self, mode="r", buffering=None, **_3to2kwargs):
"""makefile(...) -> an I/O stream connected to the socket
The arguments are as for io.open() after the filename,
except the only mode characters supported are 'r', 'w' and 'b'.
The semantics are similar too. (XXX refactor to share code?)
"""
if 'newline' in _3to2kwargs: newline = _3to2kwargs['newline']; del _3to2kwargs['newline']
else: newline = None
if 'errors' in _3to2kwargs: errors = _3to2kwargs['errors']; del _3to2kwargs['errors']
else: errors = None
if 'encoding' in _3to2kwargs: encoding = _3to2kwargs['encoding']; del _3to2kwargs['encoding']
else: encoding = None
for c in mode:
if c not in ("r", "w", "b"):
raise ValueError("invalid mode %r (only r, w, b allowed)")
writing = "w" in mode
reading = "r" in mode or not writing
assert reading or writing
binary = "b" in mode
rawmode = ""
if reading:
rawmode += "r"
if writing:
rawmode += "w"
raw = SocketIO(self, rawmode)
self._io_refs += 1
if buffering is None:
buffering = -1
if buffering < 0:
buffering = io.DEFAULT_BUFFER_SIZE
if buffering == 0:
if not binary:
raise ValueError("unbuffered streams must be binary")
return raw
if reading and writing:
buffer = io.BufferedRWPair(raw, raw, buffering)
elif reading:
buffer = io.BufferedReader(raw, buffering)
else:
assert writing
buffer = io.BufferedWriter(raw, buffering)
if binary:
return buffer
text = io.TextIOWrapper(buffer, encoding, errors, newline)
text.mode = mode
return text
def _decref_socketios(self):
if self._io_refs > 0:
self._io_refs -= 1
if self._closed:
self.close()
def _real_close(self, _ss=_socket.socket):
# This function should not reference any globals. See issue #808164.
_ss.close(self)
def close(self):
# This function should not reference any globals. See issue #808164.
self._closed = True
if self._io_refs <= 0:
self._real_close()
def detach(self):
"""detach() -> file descriptor
Close the socket object without closing the underlying file descriptor.
The object cannot be used after this call, but the file descriptor
can be reused for other purposes. The file descriptor is returned.
"""
self._closed = True
return super().detach()
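# Illustrative usage sketch (not part of the vendored module): wrap an
# already-connected socket in a buffered binary writer via makefile() and
# send one line through it.
def _makefile_example(sock):
    f = sock.makefile('wb')
    try:
        f.write(b'hello\r\n')       # buffered; flushed when the file is closed
    finally:
        f.close()                   # closes only the wrapper; sock itself stays open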
def fromfd(fd, family, type, proto=0):
""" fromfd(fd, family, type[, proto]) -> socket object
Create a socket object from a duplicate of the given file
descriptor. The remaining arguments are the same as for socket().
"""
nfd = dup(fd)
return socket(family, type, proto, nfd)
if hasattr(_socket.socket, "share"):
def fromshare(info):
""" fromshare(info) -> socket object
Create a socket object from the bytes object returned by
socket.share(pid).
"""
return socket(0, 0, 0, info)
if hasattr(_socket, "socketpair"):
def socketpair(family=None, type=SOCK_STREAM, proto=0):
"""socketpair([family[, type[, proto]]]) -> (socket object, socket object)
Create a pair of socket objects from the sockets returned by the platform
socketpair() function.
The arguments are the same as for socket() except the default family is
AF_UNIX if defined on the platform; otherwise, the default is AF_INET.
"""
if family is None:
try:
family = AF_UNIX
except NameError:
family = AF_INET
a, b = _socket.socketpair(family, type, proto)
a = socket(family, type, proto, a.detach())
b = socket(family, type, proto, b.detach())
return a, b
_blocking_errnos = set([EAGAIN, EWOULDBLOCK])
class SocketIO(io.RawIOBase):
"""Raw I/O implementation for stream sockets.
This class supports the makefile() method on sockets. It provides
the raw I/O interface on top of a socket object.
"""
# One might wonder why not let FileIO do the job instead. There are two
# main reasons why FileIO is not adapted:
# - it wouldn't work under Windows (where you can't use read() and
# write() on a socket handle)
# - it wouldn't work with socket timeouts (FileIO would ignore the
# timeout and consider the socket non-blocking)
# XXX More docs
def __init__(self, sock, mode):
if mode not in ("r", "w", "rw", "rb", "wb", "rwb"):
raise ValueError("invalid mode: %r" % mode)
io.RawIOBase.__init__(self)
self._sock = sock
if "b" not in mode:
mode += "b"
self._mode = mode
self._reading = "r" in mode
self._writing = "w" in mode
self._timeout_occurred = False
def readinto(self, b):
"""Read up to len(b) bytes into the writable buffer *b* and return
the number of bytes read. If the socket is non-blocking and no bytes
are available, None is returned.
If *b* is non-empty, a 0 return value indicates that the connection
was shutdown at the other end.
"""
self._checkClosed()
self._checkReadable()
if self._timeout_occurred:
raise IOError("cannot read from timed out object")
while True:
try:
return self._sock.recv_into(b)
except timeout:
self._timeout_occurred = True
raise
# except InterruptedError:
# continue
except error as e:
if e.args[0] in _blocking_errnos:
return None
raise
def write(self, b):
"""Write the given bytes or bytearray object *b* to the socket
and return the number of bytes written. This can be less than
len(b) if not all data could be written. If the socket is
non-blocking and no bytes could be written None is returned.
"""
self._checkClosed()
self._checkWritable()
try:
return self._sock.send(b)
except error as e:
# XXX what about EINTR?
if e.args[0] in _blocking_errnos:
return None
raise
def readable(self):
"""True if the SocketIO is open for reading.
"""
if self.closed:
raise ValueError("I/O operation on closed socket.")
return self._reading
def writable(self):
"""True if the SocketIO is open for writing.
"""
if self.closed:
raise ValueError("I/O operation on closed socket.")
return self._writing
def seekable(self):
"""True if the SocketIO is open for seeking.
"""
if self.closed:
raise ValueError("I/O operation on closed socket.")
return super().seekable()
def fileno(self):
"""Return the file descriptor of the underlying socket.
"""
self._checkClosed()
return self._sock.fileno()
@property
def name(self):
if not self.closed:
return self.fileno()
else:
return -1
@property
def mode(self):
return self._mode
def close(self):
"""Close the SocketIO object. This doesn't close the underlying
socket, except if all references to it have disappeared.
"""
if self.closed:
return
io.RawIOBase.close(self)
self._sock._decref_socketios()
self._sock = None
def getfqdn(name=''):
"""Get fully qualified domain name from name.
An empty argument is interpreted as meaning the local host.
First the hostname returned by gethostbyaddr() is checked, then
possibly existing aliases. In case no FQDN is available, hostname
from gethostname() is returned.
"""
name = name.strip()
if not name or name == '0.0.0.0':
name = gethostname()
try:
hostname, aliases, ipaddrs = gethostbyaddr(name)
except error:
pass
else:
aliases.insert(0, hostname)
for name in aliases:
if '.' in name:
break
else:
name = hostname
return name
# Re-use the same sentinel as in the Python stdlib socket module:
from socket import _GLOBAL_DEFAULT_TIMEOUT
# Was: _GLOBAL_DEFAULT_TIMEOUT = object()
def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT,
source_address=None):
"""Connect to *address* and return the socket object.
Convenience function. Connect to *address* (a 2-tuple ``(host,
port)``) and return the socket object. Passing the optional
*timeout* parameter will set the timeout on the socket instance
before attempting to connect. If no *timeout* is supplied, the
global default timeout setting returned by :func:`getdefaulttimeout`
is used. If *source_address* is set it must be a tuple of (host, port)
for the socket to bind as a source address before making the connection.
A host of '' or port 0 tells the OS to use the default.
"""
host, port = address
err = None
for res in getaddrinfo(host, port, 0, SOCK_STREAM):
af, socktype, proto, canonname, sa = res
sock = None
try:
sock = socket(af, socktype, proto)
if timeout is not _GLOBAL_DEFAULT_TIMEOUT:
sock.settimeout(timeout)
if source_address:
sock.bind(source_address)
sock.connect(sa)
return sock
except error as _:
err = _
if sock is not None:
sock.close()
if err is not None:
raise err
else:
raise error("getaddrinfo returns an empty list")

View file

@@ -0,0 +1,747 @@
"""Generic socket server classes.
This module tries to capture the various aspects of defining a server:
For socket-based servers:
- address family:
- AF_INET{,6}: IP (Internet Protocol) sockets (default)
- AF_UNIX: Unix domain sockets
- others, e.g. AF_DECNET, are conceivable (see <socket.h>)
- socket type:
- SOCK_STREAM (reliable stream, e.g. TCP)
- SOCK_DGRAM (datagrams, e.g. UDP)
For request-based servers (including socket-based):
- client address verification before further looking at the request
(This is actually a hook for any processing that needs to look
at the request before anything else, e.g. logging)
- how to handle multiple requests:
- synchronous (one request is handled at a time)
- forking (each request is handled by a new process)
- threading (each request is handled by a new thread)
The classes in this module favor the server type that is simplest to
write: a synchronous TCP/IP server. This is bad class design, but
saves some typing. (There's also the issue that a deep class hierarchy
slows down method lookups.)
There are five classes in an inheritance diagram, four of which represent
synchronous servers of four types:
+------------+
| BaseServer |
+------------+
|
v
+-----------+ +------------------+
| TCPServer |------->| UnixStreamServer |
+-----------+ +------------------+
|
v
+-----------+ +--------------------+
| UDPServer |------->| UnixDatagramServer |
+-----------+ +--------------------+
Note that UnixDatagramServer derives from UDPServer, not from
UnixStreamServer -- the only difference between an IP and a Unix
stream server is the address family, which is simply repeated in both
unix server classes.
Forking and threading versions of each type of server can be created
using the ForkingMixIn and ThreadingMixIn mix-in classes. For
instance, a threading UDP server class is created as follows:
class ThreadingUDPServer(ThreadingMixIn, UDPServer): pass
The Mix-in class must come first, since it overrides a method defined
in UDPServer! Setting the various member variables also changes
the behavior of the underlying server mechanism.
To implement a service, you must derive a class from
BaseRequestHandler and redefine its handle() method. You can then run
various versions of the service by combining one of the server classes
with your request handler class.
The request handler class must be different for datagram or stream
services. This can be hidden by using the request handler
subclasses StreamRequestHandler or DatagramRequestHandler.
Of course, you still have to use your head!
For instance, it makes no sense to use a forking server if the service
contains state in memory that can be modified by requests (since the
modifications in the child process would never reach the initial state
kept in the parent process and passed to each child). In this case,
you can use a threading server, but you will probably have to use
locks to avoid two requests that come in nearly simultaneously from applying
conflicting changes to the server state.
On the other hand, if you are building e.g. an HTTP server, where all
data is stored externally (e.g. in the file system), a synchronous
class will essentially render the service "deaf" while one request is
being handled -- which may be for a very long time if a client is slow
to read all the data it has requested. Here a threading or forking
server is appropriate.
In some cases, it may be appropriate to process part of a request
synchronously, but to finish processing in a forked child depending on
the request data. This can be implemented by using a synchronous
server and doing an explicit fork in the request handler class
handle() method.
Another approach to handling multiple simultaneous requests in an
environment that supports neither threads nor fork (or where these are
too expensive or inappropriate for the service) is to maintain an
explicit table of partially finished requests and to use select() to
decide which request to work on next (or whether to handle a new
incoming request). This is particularly important for stream services
where each client can potentially be connected for a long time (if
threads or subprocesses cannot be used).
Future work:
- Standard classes for Sun RPC (which uses either UDP or TCP)
- Standard mix-in classes to implement various authentication
and encryption schemes
- Standard framework for select-based multiplexing
XXX Open problems:
- What to do with out-of-band data?
BaseServer:
- split generic "request" functionality out into BaseServer class.
Copyright (C) 2000 Luke Kenneth Casson Leighton <lkcl@samba.org>
example: read entries from a SQL database (requires overriding
get_request() to return a table entry from the database).
entry is processed by a RequestHandlerClass.
"""
# Author of the BaseServer patch: Luke Kenneth Casson Leighton
# XXX Warning!
# There is a test suite for this module, but it cannot be run by the
# standard regression test.
# To run it manually, run Lib/test/test_socketserver.py.
from __future__ import (absolute_import, print_function)
__version__ = "0.4"
import socket
import select
import sys
import os
import errno
try:
import threading
except ImportError:
import dummy_threading as threading
__all__ = ["TCPServer","UDPServer","ForkingUDPServer","ForkingTCPServer",
"ThreadingUDPServer","ThreadingTCPServer","BaseRequestHandler",
"StreamRequestHandler","DatagramRequestHandler",
"ThreadingMixIn", "ForkingMixIn"]
if hasattr(socket, "AF_UNIX"):
__all__.extend(["UnixStreamServer","UnixDatagramServer",
"ThreadingUnixStreamServer",
"ThreadingUnixDatagramServer"])
def _eintr_retry(func, *args):
"""restart a system call interrupted by EINTR"""
while True:
try:
return func(*args)
except OSError as e:
if e.errno != errno.EINTR:
raise
class BaseServer(object):
"""Base class for server classes.
Methods for the caller:
- __init__(server_address, RequestHandlerClass)
- serve_forever(poll_interval=0.5)
- shutdown()
- handle_request() # if you do not use serve_forever()
- fileno() -> int # for select()
Methods that may be overridden:
- server_bind()
- server_activate()
- get_request() -> request, client_address
- handle_timeout()
- verify_request(request, client_address)
- server_close()
- process_request(request, client_address)
- shutdown_request(request)
- close_request(request)
- service_actions()
- handle_error()
Methods for derived classes:
- finish_request(request, client_address)
Class variables that may be overridden by derived classes or
instances:
- timeout
- address_family
- socket_type
- allow_reuse_address
Instance variables:
- RequestHandlerClass
- socket
"""
timeout = None
def __init__(self, server_address, RequestHandlerClass):
"""Constructor. May be extended, do not override."""
self.server_address = server_address
self.RequestHandlerClass = RequestHandlerClass
self.__is_shut_down = threading.Event()
self.__shutdown_request = False
def server_activate(self):
"""Called by constructor to activate the server.
May be overridden.
"""
pass
def serve_forever(self, poll_interval=0.5):
"""Handle one request at a time until shutdown.
Polls for shutdown every poll_interval seconds. Ignores
self.timeout. If you need to do periodic tasks, do them in
another thread.
"""
self.__is_shut_down.clear()
try:
while not self.__shutdown_request:
# XXX: Consider using another file descriptor or
# connecting to the socket to wake this up instead of
# polling. Polling reduces our responsiveness to a
# shutdown request and wastes cpu at all other times.
r, w, e = _eintr_retry(select.select, [self], [], [],
poll_interval)
if self in r:
self._handle_request_noblock()
self.service_actions()
finally:
self.__shutdown_request = False
self.__is_shut_down.set()
def shutdown(self):
"""Stops the serve_forever loop.
Blocks until the loop has finished. This must be called while
serve_forever() is running in another thread, or it will
deadlock.
"""
self.__shutdown_request = True
self.__is_shut_down.wait()
def service_actions(self):
"""Called by the serve_forever() loop.
May be overridden by a subclass / Mixin to implement any code that
needs to be run during the loop.
"""
pass
# The distinction between handling, getting, processing and
# finishing a request is fairly arbitrary. Remember:
#
# - handle_request() is the top-level call. It calls
# select, get_request(), verify_request() and process_request()
# - get_request() is different for stream or datagram sockets
# - process_request() is the place that may fork a new process
# or create a new thread to finish the request
# - finish_request() instantiates the request handler class;
# this constructor will handle the request all by itself
def handle_request(self):
"""Handle one request, possibly blocking.
Respects self.timeout.
"""
# Support people who used socket.settimeout() to escape
# handle_request before self.timeout was available.
timeout = self.socket.gettimeout()
if timeout is None:
timeout = self.timeout
elif self.timeout is not None:
timeout = min(timeout, self.timeout)
fd_sets = _eintr_retry(select.select, [self], [], [], timeout)
if not fd_sets[0]:
self.handle_timeout()
return
self._handle_request_noblock()
def _handle_request_noblock(self):
"""Handle one request, without blocking.
I assume that select.select has returned that the socket is
readable before this function was called, so there should be
no risk of blocking in get_request().
"""
try:
request, client_address = self.get_request()
except socket.error:
return
if self.verify_request(request, client_address):
try:
self.process_request(request, client_address)
except:
self.handle_error(request, client_address)
self.shutdown_request(request)
def handle_timeout(self):
"""Called if no new request arrives within self.timeout.
Overridden by ForkingMixIn.
"""
pass
def verify_request(self, request, client_address):
"""Verify the request. May be overridden.
Return True if we should proceed with this request.
"""
return True
def process_request(self, request, client_address):
"""Call finish_request.
Overridden by ForkingMixIn and ThreadingMixIn.
"""
self.finish_request(request, client_address)
self.shutdown_request(request)
def server_close(self):
"""Called to clean-up the server.
May be overridden.
"""
pass
def finish_request(self, request, client_address):
"""Finish one request by instantiating RequestHandlerClass."""
self.RequestHandlerClass(request, client_address, self)
def shutdown_request(self, request):
"""Called to shutdown and close an individual request."""
self.close_request(request)
def close_request(self, request):
"""Called to clean up an individual request."""
pass
def handle_error(self, request, client_address):
"""Handle an error gracefully. May be overridden.
The default is to print a traceback and continue.
"""
print('-'*40)
print('Exception happened during processing of request from', end=' ')
print(client_address)
import traceback
traceback.print_exc() # XXX But this goes to stderr!
print('-'*40)
class TCPServer(BaseServer):
"""Base class for various socket-based server classes.
Defaults to synchronous IP stream (i.e., TCP).
Methods for the caller:
- __init__(server_address, RequestHandlerClass, bind_and_activate=True)
- serve_forever(poll_interval=0.5)
- shutdown()
- handle_request() # if you don't use serve_forever()
- fileno() -> int # for select()
Methods that may be overridden:
- server_bind()
- server_activate()
- get_request() -> request, client_address
- handle_timeout()
- verify_request(request, client_address)
- process_request(request, client_address)
- shutdown_request(request)
- close_request(request)
- handle_error()
Methods for derived classes:
- finish_request(request, client_address)
Class variables that may be overridden by derived classes or
instances:
- timeout
- address_family
- socket_type
- request_queue_size (only for stream sockets)
- allow_reuse_address
Instance variables:
- server_address
- RequestHandlerClass
- socket
"""
address_family = socket.AF_INET
socket_type = socket.SOCK_STREAM
request_queue_size = 5
allow_reuse_address = False
def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True):
"""Constructor. May be extended, do not override."""
BaseServer.__init__(self, server_address, RequestHandlerClass)
self.socket = socket.socket(self.address_family,
self.socket_type)
if bind_and_activate:
self.server_bind()
self.server_activate()
def server_bind(self):
"""Called by constructor to bind the socket.
May be overridden.
"""
if self.allow_reuse_address:
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind(self.server_address)
self.server_address = self.socket.getsockname()
def server_activate(self):
"""Called by constructor to activate the server.
May be overridden.
"""
self.socket.listen(self.request_queue_size)
def server_close(self):
"""Called to clean-up the server.
May be overridden.
"""
self.socket.close()
def fileno(self):
"""Return socket file number.
Interface required by select().
"""
return self.socket.fileno()
def get_request(self):
"""Get the request and client address from the socket.
May be overridden.
"""
return self.socket.accept()
def shutdown_request(self, request):
"""Called to shutdown and close an individual request."""
try:
#explicitly shutdown. socket.close() merely releases
#the socket and waits for GC to perform the actual close.
request.shutdown(socket.SHUT_WR)
except socket.error:
pass #some platforms may raise ENOTCONN here
self.close_request(request)
def close_request(self, request):
"""Called to clean up an individual request."""
request.close()
class UDPServer(TCPServer):
"""UDP server class."""
allow_reuse_address = False
socket_type = socket.SOCK_DGRAM
max_packet_size = 8192
def get_request(self):
data, client_addr = self.socket.recvfrom(self.max_packet_size)
return (data, self.socket), client_addr
def server_activate(self):
# No need to call listen() for UDP.
pass
def shutdown_request(self, request):
# No need to shutdown anything.
self.close_request(request)
def close_request(self, request):
# No need to close anything.
pass
class ForkingMixIn(object):
"""Mix-in class to handle each request in a new process."""
timeout = 300
active_children = None
max_children = 40
def collect_children(self):
"""Internal routine to wait for children that have exited."""
if self.active_children is None: return
while len(self.active_children) >= self.max_children:
# XXX: This will wait for any child process, not just ones
# spawned by this library. This could confuse other
# libraries that expect to be able to wait for their own
# children.
try:
pid, status = os.waitpid(0, 0)
except os.error:
pid = None
if pid not in self.active_children: continue
self.active_children.remove(pid)
# XXX: This loop runs more system calls than it ought
# to. There should be a way to put the active_children into a
# process group and then use os.waitpid(-pgid) to wait for any
# of that set, but I couldn't find a way to allocate pgids
# that couldn't collide.
for child in self.active_children:
try:
pid, status = os.waitpid(child, os.WNOHANG)
except os.error:
pid = None
if not pid: continue
try:
self.active_children.remove(pid)
except ValueError as e:
raise ValueError('%s. x=%d and list=%r' % (str(e), pid,
self.active_children))
def handle_timeout(self):
"""Wait for zombies after self.timeout seconds of inactivity.
May be extended, do not override.
"""
self.collect_children()
def service_actions(self):
"""Collect the zombie child processes regularly in the ForkingMixIn.
service_actions is called in the BaseServer's serve_forever loop.
"""
self.collect_children()
def process_request(self, request, client_address):
"""Fork a new subprocess to process the request."""
pid = os.fork()
if pid:
# Parent process
if self.active_children is None:
self.active_children = []
self.active_children.append(pid)
self.close_request(request)
return
else:
# Child process.
# This must never return, hence os._exit()!
try:
self.finish_request(request, client_address)
self.shutdown_request(request)
os._exit(0)
except:
try:
self.handle_error(request, client_address)
self.shutdown_request(request)
finally:
os._exit(1)
class ThreadingMixIn(object):
"""Mix-in class to handle each request in a new thread."""
# Decides how threads will act upon termination of the
# main process
daemon_threads = False
def process_request_thread(self, request, client_address):
"""Same as in BaseServer but as a thread.
In addition, exception handling is done here.
"""
try:
self.finish_request(request, client_address)
self.shutdown_request(request)
except:
self.handle_error(request, client_address)
self.shutdown_request(request)
def process_request(self, request, client_address):
"""Start a new thread to process the request."""
t = threading.Thread(target = self.process_request_thread,
args = (request, client_address))
t.daemon = self.daemon_threads
t.start()
class ForkingUDPServer(ForkingMixIn, UDPServer): pass
class ForkingTCPServer(ForkingMixIn, TCPServer): pass
class ThreadingUDPServer(ThreadingMixIn, UDPServer): pass
class ThreadingTCPServer(ThreadingMixIn, TCPServer): pass
if hasattr(socket, 'AF_UNIX'):
class UnixStreamServer(TCPServer):
address_family = socket.AF_UNIX
class UnixDatagramServer(UDPServer):
address_family = socket.AF_UNIX
class ThreadingUnixStreamServer(ThreadingMixIn, UnixStreamServer): pass
class ThreadingUnixDatagramServer(ThreadingMixIn, UnixDatagramServer): pass
class BaseRequestHandler(object):
"""Base class for request handler classes.
This class is instantiated for each request to be handled. The
constructor sets the instance variables request, client_address
and server, and then calls the handle() method. To implement a
specific service, all you need to do is to derive a class which
defines a handle() method.
The handle() method can find the request as self.request, the
client address as self.client_address, and the server (in case it
needs access to per-server information) as self.server. Since a
separate instance is created for each request, the handle() method
can define arbitrary other instance variables.
"""
def __init__(self, request, client_address, server):
self.request = request
self.client_address = client_address
self.server = server
self.setup()
try:
self.handle()
finally:
self.finish()
def setup(self):
pass
def handle(self):
pass
def finish(self):
pass
# The following two classes make it possible to use the same service
# class for stream or datagram servers.
# Each class sets up these instance variables:
# - rfile: a file object from which the request is read
# - wfile: a file object to which the reply is written
# When the handle() method returns, wfile is flushed properly
class StreamRequestHandler(BaseRequestHandler):
"""Define self.rfile and self.wfile for stream sockets."""
# Default buffer sizes for rfile, wfile.
# We default rfile to buffered because otherwise it could be
# really slow for large data (a getc() call per byte); we make
# wfile unbuffered because (a) often after a write() we want to
# read and we need to flush the line; (b) big writes to unbuffered
# files are typically optimized by stdio even when big reads
# aren't.
rbufsize = -1
wbufsize = 0
# A timeout to apply to the request socket, if not None.
timeout = None
# Disable nagle algorithm for this socket, if True.
# Use only when wbufsize != 0, to avoid small packets.
disable_nagle_algorithm = False
def setup(self):
self.connection = self.request
if self.timeout is not None:
self.connection.settimeout(self.timeout)
if self.disable_nagle_algorithm:
self.connection.setsockopt(socket.IPPROTO_TCP,
socket.TCP_NODELAY, True)
self.rfile = self.connection.makefile('rb', self.rbufsize)
self.wfile = self.connection.makefile('wb', self.wbufsize)
def finish(self):
if not self.wfile.closed:
try:
self.wfile.flush()
except socket.error:
# A final socket error may have occurred here, such as
# the local error ECONNABORTED.
pass
self.wfile.close()
self.rfile.close()
class DatagramRequestHandler(BaseRequestHandler):
# XXX Regrettably, I cannot get this working on Linux;
# s.recvfrom() doesn't return a meaningful client address.
"""Define self.rfile and self.wfile for datagram sockets."""
def setup(self):
from io import BytesIO
self.packet, self.socket = self.request
self.rfile = BytesIO(self.packet)
self.wfile = BytesIO()
def finish(self):
self.socket.sendto(self.wfile.getvalue(), self.client_address)
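For orientation, the request-handler and server classes above are meant to be combined roughly like this; a minimal sketch, assuming the module is importable under the standard-library name socketserver (the EchoHandler name and the loopback port are illustrative, not part of this file):

import socketserver

class EchoHandler(socketserver.StreamRequestHandler):
    # StreamRequestHandler.setup() provides self.rfile / self.wfile for us.
    def handle(self):
        line = self.rfile.readline()   # read one line from the client
        self.wfile.write(line)         # echo it back; finish() flushes wfile

server = socketserver.ThreadingTCPServer(('127.0.0.1', 9999), EchoHandler)
try:
    server.serve_forever()             # ThreadingMixIn spawns one thread per request
finally:
    server.server_close()              # release the listening socket

Subclassing StreamRequestHandler rather than BaseRequestHandler is what makes rfile/wfile available, as described in the comments preceding that class.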

View file

@ -0,0 +1,9 @@
"""
test package backported for python-future.
Its primary purpose is to allow use of "import test.support" for running
the Python standard library unit tests using the new Python 3 stdlib
import location.
Python 3 renamed test.test_support to test.support.
"""

View file

@ -0,0 +1,36 @@
-----BEGIN RSA PRIVATE KEY-----
MIICXwIBAAKBgQC8ddrhm+LutBvjYcQlnH21PPIseJ1JVG2HMmN2CmZk2YukO+9L
opdJhTvbGfEj0DQs1IE8M+kTUyOmuKfVrFMKwtVeCJphrAnhoz7TYOuLBSqt7lVH
fhi/VwovESJlaBOp+WMnfhcduPEYHYx/6cnVapIkZnLt30zu2um+DzA9jQIDAQAB
AoGBAK0FZpaKj6WnJZN0RqhhK+ggtBWwBnc0U/ozgKz2j1s3fsShYeiGtW6CK5nU
D1dZ5wzhbGThI7LiOXDvRucc9n7vUgi0alqPQ/PFodPxAN/eEYkmXQ7W2k7zwsDA
IUK0KUhktQbLu8qF/m8qM86ba9y9/9YkXuQbZ3COl5ahTZrhAkEA301P08RKv3KM
oXnGU2UHTuJ1MAD2hOrPxjD4/wxA/39EWG9bZczbJyggB4RHu0I3NOSFjAm3HQm0
ANOu5QK9owJBANgOeLfNNcF4pp+UikRFqxk5hULqRAWzVxVrWe85FlPm0VVmHbb/
loif7mqjU8o1jTd/LM7RD9f2usZyE2psaw8CQQCNLhkpX3KO5kKJmS9N7JMZSc4j
oog58yeYO8BBqKKzpug0LXuQultYv2K4veaIO04iL9VLe5z9S/Q1jaCHBBuXAkEA
z8gjGoi1AOp6PBBLZNsncCvcV/0aC+1se4HxTNo2+duKSDnbq+ljqOM+E7odU+Nq
ewvIWOG//e8fssd0mq3HywJBAJ8l/c8GVmrpFTx8r/nZ2Pyyjt3dH1widooDXYSV
q6Gbf41Llo5sYAtmxdndTLASuHKecacTgZVhy0FryZpLKrU=
-----END RSA PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
Just bad cert data
-----END CERTIFICATE-----
-----BEGIN RSA PRIVATE KEY-----
MIICXwIBAAKBgQC8ddrhm+LutBvjYcQlnH21PPIseJ1JVG2HMmN2CmZk2YukO+9L
opdJhTvbGfEj0DQs1IE8M+kTUyOmuKfVrFMKwtVeCJphrAnhoz7TYOuLBSqt7lVH
fhi/VwovESJlaBOp+WMnfhcduPEYHYx/6cnVapIkZnLt30zu2um+DzA9jQIDAQAB
AoGBAK0FZpaKj6WnJZN0RqhhK+ggtBWwBnc0U/ozgKz2j1s3fsShYeiGtW6CK5nU
D1dZ5wzhbGThI7LiOXDvRucc9n7vUgi0alqPQ/PFodPxAN/eEYkmXQ7W2k7zwsDA
IUK0KUhktQbLu8qF/m8qM86ba9y9/9YkXuQbZ3COl5ahTZrhAkEA301P08RKv3KM
oXnGU2UHTuJ1MAD2hOrPxjD4/wxA/39EWG9bZczbJyggB4RHu0I3NOSFjAm3HQm0
ANOu5QK9owJBANgOeLfNNcF4pp+UikRFqxk5hULqRAWzVxVrWe85FlPm0VVmHbb/
loif7mqjU8o1jTd/LM7RD9f2usZyE2psaw8CQQCNLhkpX3KO5kKJmS9N7JMZSc4j
oog58yeYO8BBqKKzpug0LXuQultYv2K4veaIO04iL9VLe5z9S/Q1jaCHBBuXAkEA
z8gjGoi1AOp6PBBLZNsncCvcV/0aC+1se4HxTNo2+duKSDnbq+ljqOM+E7odU+Nq
ewvIWOG//e8fssd0mq3HywJBAJ8l/c8GVmrpFTx8r/nZ2Pyyjt3dH1widooDXYSV
q6Gbf41Llo5sYAtmxdndTLASuHKecacTgZVhy0FryZpLKrU=
-----END RSA PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
Just bad cert data
-----END CERTIFICATE-----

View file

@ -0,0 +1,40 @@
-----BEGIN RSA PRIVATE KEY-----
Bad Key, though the cert should be OK
-----END RSA PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIICpzCCAhCgAwIBAgIJAP+qStv1cIGNMA0GCSqGSIb3DQEBBQUAMIGJMQswCQYD
VQQGEwJVUzERMA8GA1UECBMIRGVsYXdhcmUxEzARBgNVBAcTCldpbG1pbmd0b24x
IzAhBgNVBAoTGlB5dGhvbiBTb2Z0d2FyZSBGb3VuZGF0aW9uMQwwCgYDVQQLEwNT
U0wxHzAdBgNVBAMTFnNvbWVtYWNoaW5lLnB5dGhvbi5vcmcwHhcNMDcwODI3MTY1
NDUwWhcNMTMwMjE2MTY1NDUwWjCBiTELMAkGA1UEBhMCVVMxETAPBgNVBAgTCERl
bGF3YXJlMRMwEQYDVQQHEwpXaWxtaW5ndG9uMSMwIQYDVQQKExpQeXRob24gU29m
dHdhcmUgRm91bmRhdGlvbjEMMAoGA1UECxMDU1NMMR8wHQYDVQQDExZzb21lbWFj
aGluZS5weXRob24ub3JnMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC8ddrh
m+LutBvjYcQlnH21PPIseJ1JVG2HMmN2CmZk2YukO+9LopdJhTvbGfEj0DQs1IE8
M+kTUyOmuKfVrFMKwtVeCJphrAnhoz7TYOuLBSqt7lVHfhi/VwovESJlaBOp+WMn
fhcduPEYHYx/6cnVapIkZnLt30zu2um+DzA9jQIDAQABoxUwEzARBglghkgBhvhC
AQEEBAMCBkAwDQYJKoZIhvcNAQEFBQADgYEAF4Q5BVqmCOLv1n8je/Jw9K669VXb
08hyGzQhkemEBYQd6fzQ9A/1ZzHkJKb1P6yreOLSEh4KcxYPyrLRC1ll8nr5OlCx
CMhKkTnR6qBsdNV0XtdU2+N25hqW+Ma4ZeqsN/iiJVCGNOZGnvQuvCAGWF8+J/f/
iHkC6gGdBJhogs4=
-----END CERTIFICATE-----
-----BEGIN RSA PRIVATE KEY-----
Bad Key, though the cert should be OK
-----END RSA PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIICpzCCAhCgAwIBAgIJAP+qStv1cIGNMA0GCSqGSIb3DQEBBQUAMIGJMQswCQYD
VQQGEwJVUzERMA8GA1UECBMIRGVsYXdhcmUxEzARBgNVBAcTCldpbG1pbmd0b24x
IzAhBgNVBAoTGlB5dGhvbiBTb2Z0d2FyZSBGb3VuZGF0aW9uMQwwCgYDVQQLEwNT
U0wxHzAdBgNVBAMTFnNvbWVtYWNoaW5lLnB5dGhvbi5vcmcwHhcNMDcwODI3MTY1
NDUwWhcNMTMwMjE2MTY1NDUwWjCBiTELMAkGA1UEBhMCVVMxETAPBgNVBAgTCERl
bGF3YXJlMRMwEQYDVQQHEwpXaWxtaW5ndG9uMSMwIQYDVQQKExpQeXRob24gU29m
dHdhcmUgRm91bmRhdGlvbjEMMAoGA1UECxMDU1NMMR8wHQYDVQQDExZzb21lbWFj
aGluZS5weXRob24ub3JnMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC8ddrh
m+LutBvjYcQlnH21PPIseJ1JVG2HMmN2CmZk2YukO+9LopdJhTvbGfEj0DQs1IE8
M+kTUyOmuKfVrFMKwtVeCJphrAnhoz7TYOuLBSqt7lVHfhi/VwovESJlaBOp+WMn
fhcduPEYHYx/6cnVapIkZnLt30zu2um+DzA9jQIDAQABoxUwEzARBglghkgBhvhC
AQEEBAMCBkAwDQYJKoZIhvcNAQEFBQADgYEAF4Q5BVqmCOLv1n8je/Jw9K669VXb
08hyGzQhkemEBYQd6fzQ9A/1ZzHkJKb1P6yreOLSEh4KcxYPyrLRC1ll8nr5OlCx
CMhKkTnR6qBsdNV0XtdU2+N25hqW+Ma4ZeqsN/iiJVCGNOZGnvQuvCAGWF8+J/f/
iHkC6gGdBJhogs4=
-----END CERTIFICATE-----

View file

@ -0,0 +1,9 @@
-----BEGIN DH PARAMETERS-----
MEYCQQD1Kv884bEpQBgRjXyEpwpy1obEAxnIByl6ypUM2Zafq9AKUJsCRtMIPWak
XUGfnHy9iUsiGSa6q6Jew1XpKgVfAgEC
-----END DH PARAMETERS-----
These are the 512 bit DH parameters from "Assigned Number for SKIP Protocols"
(http://www.skip-vpn.org/spec/numbers.html).
See there for how they were generated.
Note that g is not a generator, but this is not a problem since p is a safe prime.
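For context, a test server would typically hand this file to its TLS context before enabling ephemeral-DH ciphers; a minimal sketch, assuming the parameters are saved as dh512.pem next to a combined key-and-certificate file (the keycert.pem name is hypothetical), and noting that modern OpenSSL builds may refuse 512-bit parameters as below their minimum security level:

import ssl

ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ctx.load_cert_chain('keycert.pem')   # hypothetical server key + certificate file
ctx.load_dh_params('dh512.pem')      # the parameters shown above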

View file

@ -0,0 +1,41 @@
-----BEGIN CERTIFICATE-----
MIIHPTCCBSWgAwIBAgIBADANBgkqhkiG9w0BAQQFADB5MRAwDgYDVQQKEwdSb290
IENBMR4wHAYDVQQLExVodHRwOi8vd3d3LmNhY2VydC5vcmcxIjAgBgNVBAMTGUNB
IENlcnQgU2lnbmluZyBBdXRob3JpdHkxITAfBgkqhkiG9w0BCQEWEnN1cHBvcnRA
Y2FjZXJ0Lm9yZzAeFw0wMzAzMzAxMjI5NDlaFw0zMzAzMjkxMjI5NDlaMHkxEDAO
BgNVBAoTB1Jvb3QgQ0ExHjAcBgNVBAsTFWh0dHA6Ly93d3cuY2FjZXJ0Lm9yZzEi
MCAGA1UEAxMZQ0EgQ2VydCBTaWduaW5nIEF1dGhvcml0eTEhMB8GCSqGSIb3DQEJ
ARYSc3VwcG9ydEBjYWNlcnQub3JnMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIIC
CgKCAgEAziLA4kZ97DYoB1CW8qAzQIxL8TtmPzHlawI229Z89vGIj053NgVBlfkJ
8BLPRoZzYLdufujAWGSuzbCtRRcMY/pnCujW0r8+55jE8Ez64AO7NV1sId6eINm6
zWYyN3L69wj1x81YyY7nDl7qPv4coRQKFWyGhFtkZip6qUtTefWIonvuLwphK42y
fk1WpRPs6tqSnqxEQR5YYGUFZvjARL3LlPdCfgv3ZWiYUQXw8wWRBB0bF4LsyFe7
w2t6iPGwcswlWyCR7BYCEo8y6RcYSNDHBS4CMEK4JZwFaz+qOqfrU0j36NK2B5jc
G8Y0f3/JHIJ6BVgrCFvzOKKrF11myZjXnhCLotLddJr3cQxyYN/Nb5gznZY0dj4k
epKwDpUeb+agRThHqtdB7Uq3EvbXG4OKDy7YCbZZ16oE/9KTfWgu3YtLq1i6L43q
laegw1SJpfvbi1EinbLDvhG+LJGGi5Z4rSDTii8aP8bQUWWHIbEZAWV/RRyH9XzQ
QUxPKZgh/TMfdQwEUfoZd9vUFBzugcMd9Zi3aQaRIt0AUMyBMawSB3s42mhb5ivU
fslfrejrckzzAeVLIL+aplfKkQABi6F1ITe1Yw1nPkZPcCBnzsXWWdsC4PDSy826
YreQQejdIOQpvGQpQsgi3Hia/0PsmBsJUUtaWsJx8cTLc6nloQsCAwEAAaOCAc4w
ggHKMB0GA1UdDgQWBBQWtTIb1Mfz4OaO873SsDrusjkY0TCBowYDVR0jBIGbMIGY
gBQWtTIb1Mfz4OaO873SsDrusjkY0aF9pHsweTEQMA4GA1UEChMHUm9vdCBDQTEe
MBwGA1UECxMVaHR0cDovL3d3dy5jYWNlcnQub3JnMSIwIAYDVQQDExlDQSBDZXJ0
IFNpZ25pbmcgQXV0aG9yaXR5MSEwHwYJKoZIhvcNAQkBFhJzdXBwb3J0QGNhY2Vy
dC5vcmeCAQAwDwYDVR0TAQH/BAUwAwEB/zAyBgNVHR8EKzApMCegJaAjhiFodHRw
czovL3d3dy5jYWNlcnQub3JnL3Jldm9rZS5jcmwwMAYJYIZIAYb4QgEEBCMWIWh0
dHBzOi8vd3d3LmNhY2VydC5vcmcvcmV2b2tlLmNybDA0BglghkgBhvhCAQgEJxYl
aHR0cDovL3d3dy5jYWNlcnQub3JnL2luZGV4LnBocD9pZD0xMDBWBglghkgBhvhC
AQ0ESRZHVG8gZ2V0IHlvdXIgb3duIGNlcnRpZmljYXRlIGZvciBGUkVFIGhlYWQg
b3ZlciB0byBodHRwOi8vd3d3LmNhY2VydC5vcmcwDQYJKoZIhvcNAQEEBQADggIB
ACjH7pyCArpcgBLKNQodgW+JapnM8mgPf6fhjViVPr3yBsOQWqy1YPaZQwGjiHCc
nWKdpIevZ1gNMDY75q1I08t0AoZxPuIrA2jxNGJARjtT6ij0rPtmlVOKTV39O9lg
18p5aTuxZZKmxoGCXJzN600BiqXfEVWqFcofN8CCmHBh22p8lqOOLlQ+TyGpkO/c
gr/c6EWtTZBzCDyUZbAEmXZ/4rzCahWqlwQ3JNgelE5tDlG+1sSPypZt90Pf6DBl
Jzt7u0NDY8RD97LsaMzhGY4i+5jhe1o+ATc7iwiwovOVThrLm82asduycPAtStvY
sONvRUgzEv/+PDIqVPfE94rwiCPCR/5kenHA0R6mY7AHfqQv0wGP3J8rtsYIqQ+T
SCX8Ev2fQtzzxD72V7DX3WnRBnc0CkvSyqD/HMaMyRa+xMwyN2hzXwj7UfdJUzYF
CpUCTPJ5GhD22Dp1nPMd8aINcGeGG7MW9S/lpOt5hvk9C8JzC6WZrG/8Z7jlLwum
GCSNe9FINSkYQKyTYOGWhlC0elnYjyELn8+CkcY7v2vcB5G5l1YjqrZslMZIBjzk
zk6q5PYvCdxTby78dOs6Y5nCpqyJvKeyRKANihDjbPIky/qbn3BHLt4Ui9SyIAmW
omTxJBzcoTWcFbLUvFUufQb1nA5V9FrWk9p2rSVzTMVD
-----END CERTIFICATE-----

View file

@ -0,0 +1,33 @@
-----BEGIN RSA PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: DES-EDE3-CBC,1A8D9D2A02EC698A
kJYbfZ8L0sfe9Oty3gw0aloNnY5E8fegRfQLZlNoxTl6jNt0nIwI8kDJ36CZgR9c
u3FDJm/KqrfUoz8vW+qEnWhSG7QPX2wWGPHd4K94Yz/FgrRzZ0DoK7XxXq9gOtVA
AVGQhnz32p+6WhfGsCr9ArXEwRZrTk/FvzEPaU5fHcoSkrNVAGX8IpSVkSDwEDQr
Gv17+cfk99UV1OCza6yKHoFkTtrC+PZU71LomBabivS2Oc4B9hYuSR2hF01wTHP+
YlWNagZOOVtNz4oKK9x9eNQpmfQXQvPPTfusexKIbKfZrMvJoxcm1gfcZ0H/wK6P
6wmXSG35qMOOztCZNtperjs1wzEBXznyK8QmLcAJBjkfarABJX9vBEzZV0OUKhy+
noORFwHTllphbmydLhu6ehLUZMHPhzAS5UN7srtpSN81eerDMy0RMUAwA7/PofX1
94Me85Q8jP0PC9ETdsJcPqLzAPETEYu0ELewKRcrdyWi+tlLFrpE5KT/s5ecbl9l
7B61U4Kfd1PIXc/siINhU3A3bYK+845YyUArUOnKf1kEox7p1RpD7yFqVT04lRTo
cibNKATBusXSuBrp2G6GNuhWEOSafWCKJQAzgCYIp6ZTV2khhMUGppc/2H3CF6cO
zX0KtlPVZC7hLkB6HT8SxYUwF1zqWY7+/XPPdc37MeEZ87Q3UuZwqORLY+Z0hpgt
L5JXBCoklZhCAaN2GqwFLXtGiRSRFGY7xXIhbDTlE65Wv1WGGgDLMKGE1gOz3yAo
2jjG1+yAHJUdE69XTFHSqSkvaloA1W03LdMXZ9VuQJ/ySXCie6ABAQ==
-----END RSA PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIICVDCCAb2gAwIBAgIJANfHOBkZr8JOMA0GCSqGSIb3DQEBBQUAMF8xCzAJBgNV
BAYTAlhZMRcwFQYDVQQHEw5DYXN0bGUgQW50aHJheDEjMCEGA1UEChMaUHl0aG9u
IFNvZnR3YXJlIEZvdW5kYXRpb24xEjAQBgNVBAMTCWxvY2FsaG9zdDAeFw0xMDEw
MDgyMzAxNTZaFw0yMDEwMDUyMzAxNTZaMF8xCzAJBgNVBAYTAlhZMRcwFQYDVQQH
Ew5DYXN0bGUgQW50aHJheDEjMCEGA1UEChMaUHl0aG9uIFNvZnR3YXJlIEZvdW5k
YXRpb24xEjAQBgNVBAMTCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAw
gYkCgYEA21vT5isq7F68amYuuNpSFlKDPrMUCa4YWYqZRt2OZ+/3NKaZ2xAiSwr7
6MrQF70t5nLbSPpqE5+5VrS58SY+g/sXLiFd6AplH1wJZwh78DofbFYXUggktFMt
pTyiX8jtP66bkcPkDADA089RI1TQR6Ca+n7HFa7c1fabVV6i3zkCAwEAAaMYMBYw
FAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBBQUAA4GBAHPctQBEQ4wd
BJ6+JcpIraopLn8BGhbjNWj40mmRqWB/NAWF6M5ne7KpGAu7tLeG4hb1zLaldK8G
lxy2GPSRF6LFS48dpEj2HbMv2nvv6xxalDMJ9+DicWgAKTQ6bcX2j3GUkCR0g/T1
CRlNBAAlvhKzO7Clpf9l0YKBEfraJByX
-----END CERTIFICATE-----

View file

@ -0,0 +1,31 @@
-----BEGIN PRIVATE KEY-----
MIICdwIBADANBgkqhkiG9w0BAQEFAASCAmEwggJdAgEAAoGBANtb0+YrKuxevGpm
LrjaUhZSgz6zFAmuGFmKmUbdjmfv9zSmmdsQIksK++jK0Be9LeZy20j6ahOfuVa0
ufEmPoP7Fy4hXegKZR9cCWcIe/A6H2xWF1IIJLRTLaU8ol/I7T+um5HD5AwAwNPP
USNU0Eegmvp+xxWu3NX2m1Veot85AgMBAAECgYA3ZdZ673X0oexFlq7AAmrutkHt
CL7LvwrpOiaBjhyTxTeSNWzvtQBkIU8DOI0bIazA4UreAFffwtvEuPmonDb3F+Iq
SMAu42XcGyVZEl+gHlTPU9XRX7nTOXVt+MlRRRxL6t9GkGfUAXI3XxJDXW3c0vBK
UL9xqD8cORXOfE06rQJBAP8mEX1ERkR64Ptsoe4281vjTlNfIbs7NMPkUnrn9N/Y
BLhjNIfQ3HFZG8BTMLfX7kCS9D593DW5tV4Z9BP/c6cCQQDcFzCcVArNh2JSywOQ
ZfTfRbJg/Z5Lt9Fkngv1meeGNPgIMLN8Sg679pAOOWmzdMO3V706rNPzSVMME7E5
oPIfAkEA8pDddarP5tCvTTgUpmTFbakm0KoTZm2+FzHcnA4jRh+XNTjTOv98Y6Ik
eO5d1ZnKXseWvkZncQgxfdnMqqpj5wJAcNq/RVne1DbYlwWchT2Si65MYmmJ8t+F
0mcsULqjOnEMwf5e+ptq5LzwbyrHZYq5FNk7ocufPv/ZQrcSSC+cFwJBAKvOJByS
x56qyGeZLOQlWS2JS3KJo59XuLFGqcbgN9Om9xFa41Yb4N9NvplFivsvZdw3m1Q/
SPIXQuT8RMPDVNQ=
-----END PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIICVDCCAb2gAwIBAgIJANfHOBkZr8JOMA0GCSqGSIb3DQEBBQUAMF8xCzAJBgNV
BAYTAlhZMRcwFQYDVQQHEw5DYXN0bGUgQW50aHJheDEjMCEGA1UEChMaUHl0aG9u
IFNvZnR3YXJlIEZvdW5kYXRpb24xEjAQBgNVBAMTCWxvY2FsaG9zdDAeFw0xMDEw
MDgyMzAxNTZaFw0yMDEwMDUyMzAxNTZaMF8xCzAJBgNVBAYTAlhZMRcwFQYDVQQH
Ew5DYXN0bGUgQW50aHJheDEjMCEGA1UEChMaUHl0aG9uIFNvZnR3YXJlIEZvdW5k
YXRpb24xEjAQBgNVBAMTCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAw
gYkCgYEA21vT5isq7F68amYuuNpSFlKDPrMUCa4YWYqZRt2OZ+/3NKaZ2xAiSwr7
6MrQF70t5nLbSPpqE5+5VrS58SY+g/sXLiFd6AplH1wJZwh78DofbFYXUggktFMt
pTyiX8jtP66bkcPkDADA089RI1TQR6Ca+n7HFa7c1fabVV6i3zkCAwEAAaMYMBYw
FAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBBQUAA4GBAHPctQBEQ4wd
BJ6+JcpIraopLn8BGhbjNWj40mmRqWB/NAWF6M5ne7KpGAu7tLeG4hb1zLaldK8G
lxy2GPSRF6LFS48dpEj2HbMv2nvv6xxalDMJ9+DicWgAKTQ6bcX2j3GUkCR0g/T1
CRlNBAAlvhKzO7Clpf9l0YKBEfraJByX
-----END CERTIFICATE-----

View file

@ -0,0 +1,31 @@
-----BEGIN PRIVATE KEY-----
MIICdwIBADANBgkqhkiG9w0BAQEFAASCAmEwggJdAgEAAoGBAJnsJZVrppL+W5I9
zGQrrawWwE5QJpBK9nWw17mXrZ03R1cD9BamLGivVISbPlRlAVnZBEyh1ATpsB7d
CUQ+WHEvALquvx4+Yw5l+fXeiYRjrLRBYZuVy8yNtXzU3iWcGObcYRkUdiXdOyP7
sLF2YZHRvQZpzgDBKkrraeQ81w21AgMBAAECgYBEm7n07FMHWlE+0kT0sXNsLYfy
YE+QKZnJw9WkaDN+zFEEPELkhZVt5BjsMraJr6v2fIEqF0gGGJPkbenffVq2B5dC
lWUOxvJHufMK4sM3Cp6s/gOp3LP+QkzVnvJSfAyZU6l+4PGX5pLdUsXYjPxgzjzL
S36tF7/2Uv1WePyLUQJBAMsPhYzUXOPRgmbhcJiqi9A9c3GO8kvSDYTCKt3VMnqz
HBn6MQ4VQasCD1F+7jWTI0FU/3vdw8non/Fj8hhYqZcCQQDCDRdvmZqDiZnpMqDq
L6ZSrLTVtMvZXZbgwForaAD9uHj51TME7+eYT7EG2YCgJTXJ4YvRJEnPNyskwdKt
vTSTAkEAtaaN/vyemEJ82BIGStwONNw0ILsSr5cZ9tBHzqiA/tipY+e36HRFiXhP
QcU9zXlxyWkDH8iz9DSAmE2jbfoqwwJANlMJ65E543cjIlitGcKLMnvtCCLcKpb7
xSG0XJB6Lo11OKPJ66jp0gcFTSCY1Lx2CXVd+gfJrfwI1Pp562+bhwJBAJ9IfDPU
R8OpO9v1SGd8x33Owm7uXOpB9d63/T70AD1QOXjKUC4eXYbt0WWfWuny/RNPRuyh
w7DXSfUF+kPKolU=
-----END PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIICXTCCAcagAwIBAgIJAIO3upAG445fMA0GCSqGSIb3DQEBBQUAMGIxCzAJBgNV
BAYTAlhZMRcwFQYDVQQHEw5DYXN0bGUgQW50aHJheDEjMCEGA1UEChMaUHl0aG9u
IFNvZnR3YXJlIEZvdW5kYXRpb24xFTATBgNVBAMTDGZha2Vob3N0bmFtZTAeFw0x
MDEwMDkxNTAxMDBaFw0yMDEwMDYxNTAxMDBaMGIxCzAJBgNVBAYTAlhZMRcwFQYD
VQQHEw5DYXN0bGUgQW50aHJheDEjMCEGA1UEChMaUHl0aG9uIFNvZnR3YXJlIEZv
dW5kYXRpb24xFTATBgNVBAMTDGZha2Vob3N0bmFtZTCBnzANBgkqhkiG9w0BAQEF
AAOBjQAwgYkCgYEAmewllWumkv5bkj3MZCutrBbATlAmkEr2dbDXuZetnTdHVwP0
FqYsaK9UhJs+VGUBWdkETKHUBOmwHt0JRD5YcS8Auq6/Hj5jDmX59d6JhGOstEFh
m5XLzI21fNTeJZwY5txhGRR2Jd07I/uwsXZhkdG9BmnOAMEqSutp5DzXDbUCAwEA
AaMbMBkwFwYDVR0RBBAwDoIMZmFrZWhvc3RuYW1lMA0GCSqGSIb3DQEBBQUAA4GB
AH+iMClLLGSaKWgwXsmdVo4FhTZZHo8Uprrtg3N9FxEeE50btpDVQysgRt5ias3K
m+bME9zbKwvbVWD5zZdjus4pDgzwF/iHyccL8JyYhxOvS/9zmvAtFXj/APIIbZFp
IT75d9f88ScIGEtknZQejnrdhB64tYki/EqluiuKBqKD
-----END CERTIFICATE-----

Some files were not shown because too many files have changed in this diff.