diff --git a/tasktiger/__init__.py b/tasktiger/__init__.py
index a5973e47..d44cc2f5 100644
--- a/tasktiger/__init__.py
+++ b/tasktiger/__init__.py
@@ -132,10 +132,17 @@ def __init__(self, connection=None, config=None, setup_structlog=False):
 
             # If non-empty, a worker only processes the given queues.
             'ONLY_QUEUES': [],
+
+            'SERIALIZER': json.dumps,
+
+            'DESERIALIZER': json.loads,
         }
 
         if config:
             self.config.update(config)
 
+        self._serialize_data = self.config['SERIALIZER']
+        self._deserialize_data = self.config['DESERIALIZER']
+
         self.connection = connection or redis.Redis(decode_responses=True)
         self.scripts = RedisScripts(self.connection)
diff --git a/tasktiger/task.py b/tasktiger/task.py
index 87e6be8b..0c0df26b 100644
--- a/tasktiger/task.py
+++ b/tasktiger/task.py
@@ -1,5 +1,4 @@
 import datetime
-import json
 import redis
 import time
 
@@ -269,7 +268,7 @@ def delay(self, when=None):
 
         # When using ALWAYS_EAGER, make sure we have serialized the task to
         # ensure there are no serialization errors.
-        serialized_task = json.dumps(self._data)
+        serialized_task = self.tiger._serialize_data(self._data)
 
         if tiger.config['ALWAYS_EAGER'] and state == QUEUED:
             return self.execute()
@@ -330,8 +329,8 @@ def from_id(self, tiger, queue, state, task_id, load_executions=0):
             serialized_executions = []
         # XXX: No timestamp for now
         if serialized_data:
-            data = json.loads(serialized_data)
-            executions = [json.loads(e) for e in serialized_executions if e]
+            data = tiger._deserialize_data(serialized_data)
+            executions = [tiger._deserialize_data(e) for e in serialized_executions if e]
             return Task(tiger, queue=queue, _data=data, _state=state,
                         _executions=executions)
         else:
@@ -369,8 +368,8 @@ def tasks_from_queue(self, tiger, queue, state, skip=0, limit=1000,
             results = pipeline.execute()
             for serialized_data, serialized_executions, ts in zip(
                     results[0], results[1:], tss):
-                data = json.loads(serialized_data)
-                executions = [json.loads(e) for e in serialized_executions if e]
+                data = tiger._deserialize_data(serialized_data)
+                executions = [tiger._deserialize_data(e) for e in serialized_executions if e]
 
                 task = Task(tiger, queue=queue, _data=data, _state=state,
                             _ts=ts, _executions=executions)
@@ -379,7 +378,7 @@ def tasks_from_queue(self, tiger, queue, state, skip=0, limit=1000,
         else:
             data = tiger.connection.mget([tiger._key('task', item[0]) for item in items])
             for serialized_data, ts in zip(data, tss):
-                data = json.loads(serialized_data)
+                data = tiger._deserialize_data(serialized_data)
                 task = Task(tiger, queue=queue, _data=data, _state=state,
                             _ts=ts)
                 tasks.append(task)
diff --git a/tasktiger/worker.py b/tasktiger/worker.py
index 890a9013..4920afb1 100644
--- a/tasktiger/worker.py
+++ b/tasktiger/worker.py
@@ -1,6 +1,5 @@
 from collections import OrderedDict
 import errno
-import json
 import os
 import random
 import select
@@ -242,7 +241,7 @@ def _execute_forked(self, tasks, log):
             ''.join(traceback.format_exception(*exc_info)) if exc_info != (None, None, None) else None
         execution['success'] = success
         execution['host'] = socket.gethostname()
-        serialized_execution = json.dumps(execution)
+        serialized_execution = self.tiger._serialize_data(execution)
         for task in tasks:
             self.connection.rpush(self._key('task', task.id, 'executions'),
                                   serialized_execution)
@@ -359,7 +358,7 @@ def _process_from_queue(self, queue):
         tasks = []
         for task_id, serialized_task in zip(task_ids, serialized_tasks):
             if serialized_task:
-                task_data = json.loads(serialized_task)
+                task_data = self.tiger._deserialize_data(serialized_task)
             else:
                 task_data = {}
             task = Task(self.tiger, queue=queue, _data=task_data,
@@ -492,7 +491,7 @@ def _mark_done():
                 execution = self.connection.lindex(
                     self._key('task', task.id, 'executions'), -1)
                 if execution:
-                    execution = json.loads(execution)
+                    execution = self.tiger._deserialize_data(execution)
 
                     if execution.get('retry'):
                         if 'retry_method' in execution:
diff --git a/tests/__init__.py b/tests/__init__.py
index c9d8b08d..7c7ce0b1 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,3 +1,4 @@
+import pickle
 import datetime
 import json
 from multiprocessing import Pool
@@ -32,7 +33,7 @@ def _ensure_queue(typ, data):
         for name, n in data.items():
             task_ids = self.conn.zrange('t:%s:%s' % (typ, name), 0, -1)
             self.assertEqual(len(task_ids), n)
-            ret[name] = [json.loads(self.conn.get('t:task:%s' % task_id))
+            ret[name] = [self.tiger._deserialize_data(self.conn.get('t:task:%s' % task_id))
                          for task_id in task_ids]
             self.assertEqual(list(task['id'] for task in ret[name]),
                              task_ids)
@@ -46,6 +47,22 @@ def _ensure_queue(typ, data):
             'scheduled': _ensure_queue('scheduled', scheduled),
         }
 
+
+class CustomSerializerTestCase(BaseTestCase):
+    def setUp(self):
+        self.tiger = get_tiger(SERIALIZER=lambda x: pickle.dumps(x).decode('latin1'), DESERIALIZER=lambda x: pickle.loads(x.encode('latin1')))
+        self.conn = self.tiger.connection
+        self.conn.flushdb()
+
+    def test_simple_task(self):
+        self.tiger.delay(simple_task, queue='custom_ser')
+        queues = self._ensure_queues(queued={'custom_ser': 1})
+        task = queues['queued']['custom_ser'][0]
+        Worker(self.tiger).run(once=True)
+        self._ensure_queues(queued={'custom_ser': 0})
+        self.assertFalse(self.conn.exists('t:task:%s' % task['id']))
+
+
 class TestCase(BaseTestCase):
     """
     TaskTiger main test cases.
diff --git a/tests/utils.py b/tests/utils.py
index 627dd397..770a5cd4 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -5,7 +5,7 @@
 from .config import *
 
 
-def get_tiger():
+def get_tiger(**kwargs):
     """
     Sets up logging and returns a new tasktiger instance.
     """
@@ -15,7 +15,7 @@ def get_tiger():
     )
     logging.basicConfig(format='%(message)s')
     conn = redis.Redis(db=TEST_DB, decode_responses=True)
-    tiger = TaskTiger(connection=conn, config={
+    config = {
         # We need this 0 here so we don't pick up scheduled tasks when
         # doing a single worker run.
         'SELECT_TIMEOUT': 0,
@@ -27,7 +27,9 @@ def get_tiger():
         'BATCH_QUEUES': {
             'batch': 3,
         }
-    })
+    }
+    config.update(kwargs)
+    tiger = TaskTiger(connection=conn, config=config)
 
     tiger.log.setLevel(logging.CRITICAL)
     return tiger
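
Usage note: with this diff applied, any pair of functions mapping task data to and from strings can be plugged in through the new SERIALIZER/DESERIALIZER config keys. Because the Redis connection is created with decode_responses=True, the serializer must return str rather than bytes (hence the latin1 round trip in the pickle-based test above). A minimal sketch, assuming the changes above are merged; compact_dumps is an illustrative helper, not part of the library:

    import json

    import redis
    from tasktiger import TaskTiger

    def compact_dumps(obj):
        # Illustrative custom serializer: compact JSON shrinks the task
        # payloads stored in Redis; it must return a str, not bytes.
        return json.dumps(obj, separators=(',', ':'))

    conn = redis.Redis(decode_responses=True)
    tiger = TaskTiger(connection=conn, config={
        'SERIALIZER': compact_dumps,   # object -> str
        'DESERIALIZER': json.loads,    # str -> object
    })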