Source code for sheepdog.client

# Sheepdog
# Copyright 2013 Adam Greig
#
# Released under the MIT license. See LICENSE file for details.

"""
Sheepdog's clientside code.

This code is typically only run on the worker, and this file is currently
only used by pasting it into a job file (as workers don't generally have
sheepdog itself installed).
"""

import re
import time
import json
import base64
import resource
import traceback

try:
    from urllib.error import URLError
    from urllib.parse import urlencode
    from urllib.request import urlopen, Request
except ImportError:
    from urllib import urlencode
    from urllib2 import urlopen, Request, URLError

try:
    deserialise_function
except NameError:
    from sheepdog.serialisation import (deserialise_function,
                                        deserialise_arg,
                                        deserialise_namespace,
                                        serialise_pickle)


class Client:
    """Find out what to do, do it, report back.

    A Client fetches a serialised function, its arguments and its namespace
    from the sheepdog server (``get_details``), runs it (``run``) and posts
    either the pickled result (``submit_results``) or a traceback
    (``_submit_error``) back to the server. ``go`` chains all three.
    """

    # Number of attempts made for each HTTP request before giving up.
    HTTP_RETRIES = 10

    # Matches a job-file line such as "#$ -l mem_grab=4G"; group 1 is the
    # amount, group 2 the optional (case-insensitive) K/M/G/T suffix.
    _MEM_GRAB_RE = re.compile(r"^#\$ -l mem_grab=([0-9]+)([kmgtKMGT]?)$",
                              re.MULTILINE)
    _MEM_SCALES = {'': 1,
                   'K': 1024,
                   'M': 1024 ** 2,
                   'G': 1024 ** 3,
                   'T': 1024 ** 4}

    def __init__(self, url, password, request_id, job_index):
        """Store connection details and precompute the auth header.

        :param url: base URL of the sheepdog server endpoint.
        :param password: password sent via HTTP Basic auth as user "sheepdog".
        :param request_id: identifier of the request this job belongs to.
        :param job_index: index of this job within the request.
        """
        self.url = url
        self.password = password
        self.request_id = request_id
        self.job_index = job_index
        userpass = ("sheepdog:" + self.password).encode()
        authstr = "Basic " + base64.b64encode(userpass).decode()
        self.authhdr = {"Authorization": authstr}

    @classmethod
    def _parse_mem_grab(cls, contents):
        """Return the byte limit from a ``#$ -l mem_grab=N[KMGT]`` line.

        Returns None when `contents` contains no such line.
        """
        match = cls._MEM_GRAB_RE.search(contents)
        if not match:
            return None
        scale = cls._MEM_SCALES[match.group(2).upper()]
        return int(match.group(1)) * scale

    def set_memlimit(self, fname=__file__):
        """Apply the job file's mem_grab directive as RLIMIT_AS.

        Reads `fname` (by default the file this code lives in) and, if it
        contains a ``#$ -l mem_grab=...`` line, sets both the soft and hard
        address-space limit to that many bytes. Otherwise does nothing.
        """
        with open(fname) as f:
            contents = f.read()
        limit = self._parse_mem_grab(contents)
        if limit is not None:
            print("Setting RLIMIT_AS to {}".format(limit))
            resource.setrlimit(resource.RLIMIT_AS, (limit, limit))

    def _urlopen_with_retries(self, req, failmsg):
        """Open `req`, retrying on URLError up to HTTP_RETRIES times.

        Sleeps one second between failed attempts (but not after the last
        one). Returns the open response; the caller must close it.

        :raises RuntimeError: with message `failmsg` if every attempt fails.
        """
        for attempt in range(self.HTTP_RETRIES):
            try:
                return urlopen(req)
            except URLError:
                # HTTPError is a URLError subclass, so HTTP error statuses
                # are retried as well, matching the original behaviour.
                if attempt + 1 < self.HTTP_RETRIES:
                    time.sleep(1)
        raise RuntimeError(failmsg)

    def get_details(self):
        """Retrieve the function to run and arguments to run with from the
        server.

        Sets ``self.args``, ``self.ns`` and ``self.func`` from the JSON
        document returned by the server.

        :raises RuntimeError: if the server cannot be reached.
        """
        # Ordered pairs keep "request_id" before "job_index" on every Python.
        query = urlencode([("request_id", self.request_id),
                           ("job_index", self.job_index)])
        url = "{0}?{1}".format(self.url, query)
        print("Fetching URL: {}".format(url))
        req = Request(url, headers=self.authhdr)
        response = self._urlopen_with_retries(
            req, "Could not connect to server.")
        try:
            body = response.read()
        finally:
            response.close()
        result = json.loads(body.decode())
        self.args = deserialise_arg(result['args'])
        self.ns = deserialise_namespace(result['ns'])
        self.func = deserialise_function(result['func'], self.ns)

    def run(self):
        """Run the downloaded function, storing the result in self.result.

        Any exception raised while applying the memory limit, calling the
        function or printing its result is reported to the server via
        ``_submit_error``. In that case ``self.result`` is left unset, so
        ``go`` reports exactly one outcome.

        :raises RuntimeError: if ``get_details`` has not been called yet.
        """
        if not hasattr(self, 'func') or not hasattr(self, 'args'):
            raise RuntimeError("Must call `get_details` before `run`.")
        try:
            self.set_memlimit()
            result = self.func(*self.args)
            print(result)
        except BaseException:
            # Deliberately catch everything (including SystemExit raised by
            # user code) so the server is always told the job failed.
            self._submit_error(traceback.format_exc())
        else:
            self.result = result

    def submit_results(self):
        """Pickle ``self.result`` and POST it to the server.

        :raises RuntimeError: if ``run`` has not produced a result.
        """
        if not hasattr(self, 'result'):
            raise RuntimeError("Must call `run` before `submit_results`.")
        result = serialise_pickle(self.result)
        self._submit(self.url, dict(result=result))

    def _submit_error(self, error):
        """POST the text of `error` to the server's error endpoint."""
        self._submit(self.url + "error", dict(error=str(error)))

    def _submit(self, url, data):
        """POST `data` plus this job's identifiers to `url`.

        `data` itself is not modified.

        :raises RuntimeError: if the server cannot be reached.
        """
        payload = dict(data)
        payload.update(
            {"request_id": self.request_id, "job_index": self.job_index})
        encoded_data = urlencode(payload).encode()
        req = Request(url, data=encoded_data, headers=self.authhdr)
        response = self._urlopen_with_retries(
            req, "Could not submit to server.")
        response.close()

    def go(self):
        """Call get_details(), run(), submit_results().
        Just for convenience.

        Results are only submitted when run() stored one; failures have
        already been reported by run().
        """
        self.get_details()
        self.run()
        if hasattr(self, 'result'):
            self.submit_results()