1
 
# Copyright (C) 2006-2010 Canonical Ltd
 
3
 
# This program is free software; you can redistribute it and/or modify
 
4
 
# it under the terms of the GNU General Public License as published by
 
5
 
# the Free Software Foundation; either version 2 of the License, or
 
6
 
# (at your option) any later version.
 
8
 
# This program is distributed in the hope that it will be useful,
 
9
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
 
# GNU General Public License for more details.
 
13
 
# You should have received a copy of the GNU General Public License
 
14
 
# along with this program; if not, write to the Free Software
 
15
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
17
 
"""Wire-level encoding and decoding of requests and responses for the smart
 
22
 
from cStringIO import StringIO
 
35
 
from bzrlib.smart import message, request
 
36
 
from bzrlib.trace import log_exception_quietly, mutter
 
37
 
from bzrlib.bencode import bdecode_as_tuple, bencode
 
40
 
# Protocol version strings.  These are sent as prefixes of bzr requests and
 
41
 
# responses to identify the protocol version being used. (There are no version
 
42
 
# one strings because that version doesn't send any).
 
43
 
REQUEST_VERSION_TWO = 'bzr request 2\n'
 
44
 
RESPONSE_VERSION_TWO = 'bzr response 2\n'
 
46
 
MESSAGE_VERSION_THREE = 'bzr message 3 (bzr 1.6)\n'
 
47
 
RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
 
50
 
def _recv_tuple(from_file):
 
51
 
    req_line = from_file.readline()
 
52
 
    return _decode_tuple(req_line)
 
55
 
def _decode_tuple(req_line):
 
56
 
    if req_line is None or req_line == '':
 
58
 
    if req_line[-1] != '\n':
 
59
 
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
 
60
 
    return tuple(req_line[:-1].split('\x01'))
 
63
 
def _encode_tuple(args):
 
64
 
    """Encode the tuple args to a bytestream."""
 
65
 
    joined = '\x01'.join(args) + '\n'
 
66
 
    if type(joined) is unicode:
 
67
 
        # XXX: We should fix things so this never happens!  -AJB, 20100304
 
68
 
        mutter('response args contain unicode, should be only bytes: %r',
 
70
 
        joined = joined.encode('ascii')
 
74
 
class Requester(object):
 
75
 
    """Abstract base class for an object that can issue requests on a smart
 
79
 
    def call(self, *args):
 
80
 
        """Make a remote call.
 
82
 
        :param args: the arguments of this call.
 
84
 
        raise NotImplementedError(self.call)
 
86
 
    def call_with_body_bytes(self, args, body):
 
87
 
        """Make a remote call with a body.
 
89
 
        :param args: the arguments of this call.
 
91
 
        :param body: the body to send with the request.
 
93
 
        raise NotImplementedError(self.call_with_body_bytes)
 
95
 
    def call_with_body_readv_array(self, args, body):
 
96
 
        """Make a remote call with a readv array.
 
98
 
        :param args: the arguments of this call.
 
99
 
        :type body: iterable of (start, length) tuples.
 
100
 
        :param body: the readv ranges to send with this request.
 
102
 
        raise NotImplementedError(self.call_with_body_readv_array)
 
104
 
    def set_headers(self, headers):
 
105
 
        raise NotImplementedError(self.set_headers)
 
108
 
class SmartProtocolBase(object):
 
109
 
    """Methods common to client and server"""
 
111
 
    # TODO: this only actually accomodates a single block; possibly should
 
112
 
    # support multiple chunks?
 
113
 
    def _encode_bulk_data(self, body):
 
114
 
        """Encode body as a bulk data chunk."""
 
115
 
        return ''.join(('%d\n' % len(body), body, 'done\n'))
 
117
 
    def _serialise_offsets(self, offsets):
 
118
 
        """Serialise a readv offset list."""
 
120
 
        for start, length in offsets:
 
121
 
            txt.append('%d,%d' % (start, length))
 
122
 
        return '\n'.join(txt)
 
125
 
class SmartServerRequestProtocolOne(SmartProtocolBase):
 
126
 
    """Server-side encoding and decoding logic for smart version 1."""
 
128
 
    def __init__(self, backing_transport, write_func, root_client_path='/',
 
130
 
        self._backing_transport = backing_transport
 
131
 
        self._root_client_path = root_client_path
 
132
 
        self._jail_root = jail_root
 
133
 
        self.unused_data = ''
 
134
 
        self._finished = False
 
136
 
        self._has_dispatched = False
 
138
 
        self._body_decoder = None
 
139
 
        self._write_func = write_func
 
141
 
    def accept_bytes(self, bytes):
 
142
 
        """Take bytes, and advance the internal state machine appropriately.
 
144
 
        :param bytes: must be a byte string
 
146
 
        if not isinstance(bytes, str):
 
147
 
            raise ValueError(bytes)
 
148
 
        self.in_buffer += bytes
 
149
 
        if not self._has_dispatched:
 
150
 
            if '\n' not in self.in_buffer:
 
151
 
                # no command line yet
 
153
 
            self._has_dispatched = True
 
155
 
                first_line, self.in_buffer = self.in_buffer.split('\n', 1)
 
157
 
                req_args = _decode_tuple(first_line)
 
158
 
                self.request = request.SmartServerRequestHandler(
 
159
 
                    self._backing_transport, commands=request.request_handlers,
 
160
 
                    root_client_path=self._root_client_path,
 
161
 
                    jail_root=self._jail_root)
 
162
 
                self.request.args_received(req_args)
 
163
 
                if self.request.finished_reading:
 
165
 
                    self.unused_data = self.in_buffer
 
167
 
                    self._send_response(self.request.response)
 
168
 
            except KeyboardInterrupt:
 
170
 
            except errors.UnknownSmartMethod, err:
 
171
 
                protocol_error = errors.SmartProtocolError(
 
172
 
                    "bad request %r" % (err.verb,))
 
173
 
                failure = request.FailedSmartServerResponse(
 
174
 
                    ('error', str(protocol_error)))
 
175
 
                self._send_response(failure)
 
177
 
            except Exception, exception:
 
178
 
                # everything else: pass to client, flush, and quit
 
179
 
                log_exception_quietly()
 
180
 
                self._send_response(request.FailedSmartServerResponse(
 
181
 
                    ('error', str(exception))))
 
184
 
        if self._has_dispatched:
 
186
 
                # nothing to do.XXX: this routine should be a single state
 
188
 
                self.unused_data += self.in_buffer
 
191
 
            if self._body_decoder is None:
 
192
 
                self._body_decoder = LengthPrefixedBodyDecoder()
 
193
 
            self._body_decoder.accept_bytes(self.in_buffer)
 
194
 
            self.in_buffer = self._body_decoder.unused_data
 
195
 
            body_data = self._body_decoder.read_pending_data()
 
196
 
            self.request.accept_body(body_data)
 
197
 
            if self._body_decoder.finished_reading:
 
198
 
                self.request.end_of_body()
 
199
 
                if not self.request.finished_reading:
 
200
 
                    raise AssertionError("no more body, request not finished")
 
201
 
            if self.request.response is not None:
 
202
 
                self._send_response(self.request.response)
 
203
 
                self.unused_data = self.in_buffer
 
206
 
                if self.request.finished_reading:
 
207
 
                    raise AssertionError(
 
208
 
                        "no response and we have finished reading.")
 
210
 
    def _send_response(self, response):
 
211
 
        """Send a smart server response down the output stream."""
 
213
 
            raise AssertionError('response already sent')
 
216
 
        self._finished = True
 
217
 
        self._write_protocol_version()
 
218
 
        self._write_success_or_failure_prefix(response)
 
219
 
        self._write_func(_encode_tuple(args))
 
221
 
            if not isinstance(body, str):
 
222
 
                raise ValueError(body)
 
223
 
            bytes = self._encode_bulk_data(body)
 
224
 
            self._write_func(bytes)
 
226
 
    def _write_protocol_version(self):
 
227
 
        """Write any prefixes this protocol requires.
 
229
 
        Version one doesn't send protocol versions.
 
232
 
    def _write_success_or_failure_prefix(self, response):
 
233
 
        """Write the protocol specific success/failure prefix.
 
235
 
        For SmartServerRequestProtocolOne this is omitted but we
 
236
 
        call is_successful to ensure that the response is valid.
 
238
 
        response.is_successful()
 
240
 
    def next_read_size(self):
 
243
 
        if self._body_decoder is None:
 
246
 
            return self._body_decoder.next_read_size()
 
249
 
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
 
250
 
    r"""Version two of the server side of the smart protocol.
 
252
 
    This prefixes responses with the value of RESPONSE_VERSION_TWO.
 
255
 
    response_marker = RESPONSE_VERSION_TWO
 
256
 
    request_marker = REQUEST_VERSION_TWO
 
258
 
    def _write_success_or_failure_prefix(self, response):
 
259
 
        """Write the protocol specific success/failure prefix."""
 
260
 
        if response.is_successful():
 
261
 
            self._write_func('success\n')
 
263
 
            self._write_func('failed\n')
 
265
 
    def _write_protocol_version(self):
 
266
 
        r"""Write any prefixes this protocol requires.
 
268
 
        Version two sends the value of RESPONSE_VERSION_TWO.
 
270
 
        self._write_func(self.response_marker)
 
272
 
    def _send_response(self, response):
 
273
 
        """Send a smart server response down the output stream."""
 
275
 
            raise AssertionError('response already sent')
 
276
 
        self._finished = True
 
277
 
        self._write_protocol_version()
 
278
 
        self._write_success_or_failure_prefix(response)
 
279
 
        self._write_func(_encode_tuple(response.args))
 
280
 
        if response.body is not None:
 
281
 
            if not isinstance(response.body, str):
 
282
 
                raise AssertionError('body must be a str')
 
283
 
            if not (response.body_stream is None):
 
284
 
                raise AssertionError(
 
285
 
                    'body_stream and body cannot both be set')
 
286
 
            bytes = self._encode_bulk_data(response.body)
 
287
 
            self._write_func(bytes)
 
288
 
        elif response.body_stream is not None:
 
289
 
            _send_stream(response.body_stream, self._write_func)
 
292
 
def _send_stream(stream, write_func):
 
293
 
    write_func('chunked\n')
 
294
 
    _send_chunks(stream, write_func)
 
298
 
def _send_chunks(stream, write_func):
 
300
 
        if isinstance(chunk, str):
 
301
 
            bytes = "%x\n%s" % (len(chunk), chunk)
 
303
 
        elif isinstance(chunk, request.FailedSmartServerResponse):
 
305
 
            _send_chunks(chunk.args, write_func)
 
308
 
            raise errors.BzrError(
 
309
 
                'Chunks must be str or FailedSmartServerResponse, got %r'
 
313
 
class _NeedMoreBytes(Exception):
 
314
 
    """Raise this inside a _StatefulDecoder to stop decoding until more bytes
 
318
 
    def __init__(self, count=None):
 
321
 
        :param count: the total number of bytes needed by the current state.
 
322
 
            May be None if the number of bytes needed is unknown.
 
327
 
class _StatefulDecoder(object):
 
328
 
    """Base class for writing state machines to decode byte streams.
 
330
 
    Subclasses should provide a self.state_accept attribute that accepts bytes
 
331
 
    and, if appropriate, updates self.state_accept to a different function.
 
332
 
    accept_bytes will call state_accept as often as necessary to make sure the
 
333
 
    state machine has progressed as far as possible before it returns.
 
335
 
    See ProtocolThreeDecoder for an example subclass.
 
339
 
        self.finished_reading = False
 
340
 
        self._in_buffer_list = []
 
341
 
        self._in_buffer_len = 0
 
342
 
        self.unused_data = ''
 
343
 
        self.bytes_left = None
 
344
 
        self._number_needed_bytes = None
 
346
 
    def _get_in_buffer(self):
 
347
 
        if len(self._in_buffer_list) == 1:
 
348
 
            return self._in_buffer_list[0]
 
349
 
        in_buffer = ''.join(self._in_buffer_list)
 
350
 
        if len(in_buffer) != self._in_buffer_len:
 
351
 
            raise AssertionError(
 
352
 
                "Length of buffer did not match expected value: %s != %s"
 
353
 
                % self._in_buffer_len, len(in_buffer))
 
354
 
        self._in_buffer_list = [in_buffer]
 
357
 
    def _get_in_bytes(self, count):
 
358
 
        """Grab X bytes from the input_buffer.
 
360
 
        Callers should have already checked that self._in_buffer_len is >
 
361
 
        count. Note, this does not consume the bytes from the buffer. The
 
362
 
        caller will still need to call _get_in_buffer() and then
 
363
 
        _set_in_buffer() if they actually need to consume the bytes.
 
365
 
        # check if we can yield the bytes from just the first entry in our list
 
366
 
        if len(self._in_buffer_list) == 0:
 
367
 
            raise AssertionError('Callers must be sure we have buffered bytes'
 
368
 
                ' before calling _get_in_bytes')
 
369
 
        if len(self._in_buffer_list[0]) > count:
 
370
 
            return self._in_buffer_list[0][:count]
 
371
 
        # We can't yield it from the first buffer, so collapse all buffers, and
 
373
 
        in_buf = self._get_in_buffer()
 
374
 
        return in_buf[:count]
 
376
 
    def _set_in_buffer(self, new_buf):
 
377
 
        if new_buf is not None:
 
378
 
            self._in_buffer_list = [new_buf]
 
379
 
            self._in_buffer_len = len(new_buf)
 
381
 
            self._in_buffer_list = []
 
382
 
            self._in_buffer_len = 0
 
384
 
    def accept_bytes(self, bytes):
 
385
 
        """Decode as much of bytes as possible.
 
387
 
        If 'bytes' contains too much data it will be appended to
 
390
 
        finished_reading will be set when no more data is required.  Further
 
391
 
        data will be appended to self.unused_data.
 
393
 
        # accept_bytes is allowed to change the state
 
394
 
        self._number_needed_bytes = None
 
395
 
        # lsprof puts a very large amount of time on this specific call for
 
397
 
        self._in_buffer_list.append(bytes)
 
398
 
        self._in_buffer_len += len(bytes)
 
400
 
            # Run the function for the current state.
 
401
 
            current_state = self.state_accept
 
403
 
            while current_state != self.state_accept:
 
404
 
                # The current state has changed.  Run the function for the new
 
405
 
                # current state, so that it can:
 
406
 
                #   - decode any unconsumed bytes left in a buffer, and
 
407
 
                #   - signal how many more bytes are expected (via raising
 
409
 
                current_state = self.state_accept
 
411
 
        except _NeedMoreBytes, e:
 
412
 
            self._number_needed_bytes = e.count
 
415
 
class ChunkedBodyDecoder(_StatefulDecoder):
 
416
 
    """Decoder for chunked body data.
 
418
 
    This is very similar the HTTP's chunked encoding.  See the description of
 
419
 
    streamed body data in `doc/developers/network-protocol.txt` for details.
 
423
 
        _StatefulDecoder.__init__(self)
 
424
 
        self.state_accept = self._state_accept_expecting_header
 
425
 
        self.chunk_in_progress = None
 
426
 
        self.chunks = collections.deque()
 
428
 
        self.error_in_progress = None
 
430
 
    def next_read_size(self):
 
431
 
        # Note: the shortest possible chunk is 2 bytes: '0\n', and the
 
432
 
        # end-of-body marker is 4 bytes: 'END\n'.
 
433
 
        if self.state_accept == self._state_accept_reading_chunk:
 
434
 
            # We're expecting more chunk content.  So we're expecting at least
 
435
 
            # the rest of this chunk plus an END chunk.
 
436
 
            return self.bytes_left + 4
 
437
 
        elif self.state_accept == self._state_accept_expecting_length:
 
438
 
            if self._in_buffer_len == 0:
 
439
 
                # We're expecting a chunk length.  There's at least two bytes
 
440
 
                # left: a digit plus '\n'.
 
443
 
                # We're in the middle of reading a chunk length.  So there's at
 
444
 
                # least one byte left, the '\n' that terminates the length.
 
446
 
        elif self.state_accept == self._state_accept_reading_unused:
 
448
 
        elif self.state_accept == self._state_accept_expecting_header:
 
449
 
            return max(0, len('chunked\n') - self._in_buffer_len)
 
451
 
            raise AssertionError("Impossible state: %r" % (self.state_accept,))
 
453
 
    def read_next_chunk(self):
 
455
 
            return self.chunks.popleft()
 
459
 
    def _extract_line(self):
 
460
 
        in_buf = self._get_in_buffer()
 
461
 
        pos = in_buf.find('\n')
 
463
 
            # We haven't read a complete line yet, so request more bytes before
 
465
 
            raise _NeedMoreBytes(1)
 
467
 
        # Trim the prefix (including '\n' delimiter) from the _in_buffer.
 
468
 
        self._set_in_buffer(in_buf[pos+1:])
 
472
 
        self.unused_data = self._get_in_buffer()
 
473
 
        self._in_buffer_list = []
 
474
 
        self._in_buffer_len = 0
 
475
 
        self.state_accept = self._state_accept_reading_unused
 
477
 
            error_args = tuple(self.error_in_progress)
 
478
 
            self.chunks.append(request.FailedSmartServerResponse(error_args))
 
479
 
            self.error_in_progress = None
 
480
 
        self.finished_reading = True
 
482
 
    def _state_accept_expecting_header(self):
 
483
 
        prefix = self._extract_line()
 
484
 
        if prefix == 'chunked':
 
485
 
            self.state_accept = self._state_accept_expecting_length
 
487
 
            raise errors.SmartProtocolError(
 
488
 
                'Bad chunked body header: "%s"' % (prefix,))
 
490
 
    def _state_accept_expecting_length(self):
 
491
 
        prefix = self._extract_line()
 
494
 
            self.error_in_progress = []
 
495
 
            self._state_accept_expecting_length()
 
497
 
        elif prefix == 'END':
 
498
 
            # We've read the end-of-body marker.
 
499
 
            # Any further bytes are unused data, including the bytes left in
 
504
 
            self.bytes_left = int(prefix, 16)
 
505
 
            self.chunk_in_progress = ''
 
506
 
            self.state_accept = self._state_accept_reading_chunk
 
508
 
    def _state_accept_reading_chunk(self):
 
509
 
        in_buf = self._get_in_buffer()
 
510
 
        in_buffer_len = len(in_buf)
 
511
 
        self.chunk_in_progress += in_buf[:self.bytes_left]
 
512
 
        self._set_in_buffer(in_buf[self.bytes_left:])
 
513
 
        self.bytes_left -= in_buffer_len
 
514
 
        if self.bytes_left <= 0:
 
515
 
            # Finished with chunk
 
516
 
            self.bytes_left = None
 
518
 
                self.error_in_progress.append(self.chunk_in_progress)
 
520
 
                self.chunks.append(self.chunk_in_progress)
 
521
 
            self.chunk_in_progress = None
 
522
 
            self.state_accept = self._state_accept_expecting_length
 
524
 
    def _state_accept_reading_unused(self):
 
525
 
        self.unused_data += self._get_in_buffer()
 
526
 
        self._in_buffer_list = []
 
529
 
class LengthPrefixedBodyDecoder(_StatefulDecoder):
 
530
 
    """Decodes the length-prefixed bulk data."""
 
533
 
        _StatefulDecoder.__init__(self)
 
534
 
        self.state_accept = self._state_accept_expecting_length
 
535
 
        self.state_read = self._state_read_no_data
 
537
 
        self._trailer_buffer = ''
 
539
 
    def next_read_size(self):
 
540
 
        if self.bytes_left is not None:
 
541
 
            # Ideally we want to read all the remainder of the body and the
 
543
 
            return self.bytes_left + 5
 
544
 
        elif self.state_accept == self._state_accept_reading_trailer:
 
545
 
            # Just the trailer left
 
546
 
            return 5 - len(self._trailer_buffer)
 
547
 
        elif self.state_accept == self._state_accept_expecting_length:
 
548
 
            # There's still at least 6 bytes left ('\n' to end the length, plus
 
552
 
            # Reading excess data.  Either way, 1 byte at a time is fine.
 
555
 
    def read_pending_data(self):
 
556
 
        """Return any pending data that has been decoded."""
 
557
 
        return self.state_read()
 
559
 
    def _state_accept_expecting_length(self):
 
560
 
        in_buf = self._get_in_buffer()
 
561
 
        pos = in_buf.find('\n')
 
564
 
        self.bytes_left = int(in_buf[:pos])
 
565
 
        self._set_in_buffer(in_buf[pos+1:])
 
566
 
        self.state_accept = self._state_accept_reading_body
 
567
 
        self.state_read = self._state_read_body_buffer
 
569
 
    def _state_accept_reading_body(self):
 
570
 
        in_buf = self._get_in_buffer()
 
572
 
        self.bytes_left -= len(in_buf)
 
573
 
        self._set_in_buffer(None)
 
574
 
        if self.bytes_left <= 0:
 
576
 
            if self.bytes_left != 0:
 
577
 
                self._trailer_buffer = self._body[self.bytes_left:]
 
578
 
                self._body = self._body[:self.bytes_left]
 
579
 
            self.bytes_left = None
 
580
 
            self.state_accept = self._state_accept_reading_trailer
 
582
 
    def _state_accept_reading_trailer(self):
 
583
 
        self._trailer_buffer += self._get_in_buffer()
 
584
 
        self._set_in_buffer(None)
 
585
 
        # TODO: what if the trailer does not match "done\n"?  Should this raise
 
586
 
        # a ProtocolViolation exception?
 
587
 
        if self._trailer_buffer.startswith('done\n'):
 
588
 
            self.unused_data = self._trailer_buffer[len('done\n'):]
 
589
 
            self.state_accept = self._state_accept_reading_unused
 
590
 
            self.finished_reading = True
 
592
 
    def _state_accept_reading_unused(self):
 
593
 
        self.unused_data += self._get_in_buffer()
 
594
 
        self._set_in_buffer(None)
 
596
 
    def _state_read_no_data(self):
 
599
 
    def _state_read_body_buffer(self):
 
605
 
class SmartClientRequestProtocolOne(SmartProtocolBase, Requester,
 
606
 
                                    message.ResponseHandler):
 
607
 
    """The client-side protocol for smart version 1."""
 
609
 
    def __init__(self, request):
 
610
 
        """Construct a SmartClientRequestProtocolOne.
 
612
 
        :param request: A SmartClientMediumRequest to serialise onto and
 
615
 
        self._request = request
 
616
 
        self._body_buffer = None
 
617
 
        self._request_start_time = None
 
618
 
        self._last_verb = None
 
621
 
    def set_headers(self, headers):
 
622
 
        self._headers = dict(headers)
 
624
 
    def call(self, *args):
 
625
 
        if 'hpss' in debug.debug_flags:
 
626
 
            mutter('hpss call:   %s', repr(args)[1:-1])
 
627
 
            if getattr(self._request._medium, 'base', None) is not None:
 
628
 
                mutter('             (to %s)', self._request._medium.base)
 
629
 
            self._request_start_time = osutils.timer_func()
 
630
 
        self._write_args(args)
 
631
 
        self._request.finished_writing()
 
632
 
        self._last_verb = args[0]
 
634
 
    def call_with_body_bytes(self, args, body):
 
635
 
        """Make a remote call of args with body bytes 'body'.
 
637
 
        After calling this, call read_response_tuple to find the result out.
 
639
 
        if 'hpss' in debug.debug_flags:
 
640
 
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
 
641
 
            if getattr(self._request._medium, '_path', None) is not None:
 
642
 
                mutter('                  (to %s)', self._request._medium._path)
 
643
 
            mutter('              %d bytes', len(body))
 
644
 
            self._request_start_time = osutils.timer_func()
 
645
 
            if 'hpssdetail' in debug.debug_flags:
 
646
 
                mutter('hpss body content: %s', body)
 
647
 
        self._write_args(args)
 
648
 
        bytes = self._encode_bulk_data(body)
 
649
 
        self._request.accept_bytes(bytes)
 
650
 
        self._request.finished_writing()
 
651
 
        self._last_verb = args[0]
 
653
 
    def call_with_body_readv_array(self, args, body):
 
654
 
        """Make a remote call with a readv array.
 
656
 
        The body is encoded with one line per readv offset pair. The numbers in
 
657
 
        each pair are separated by a comma, and no trailing \n is emitted.
 
659
 
        if 'hpss' in debug.debug_flags:
 
660
 
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
 
661
 
            if getattr(self._request._medium, '_path', None) is not None:
 
662
 
                mutter('                  (to %s)', self._request._medium._path)
 
663
 
            self._request_start_time = osutils.timer_func()
 
664
 
        self._write_args(args)
 
665
 
        readv_bytes = self._serialise_offsets(body)
 
666
 
        bytes = self._encode_bulk_data(readv_bytes)
 
667
 
        self._request.accept_bytes(bytes)
 
668
 
        self._request.finished_writing()
 
669
 
        if 'hpss' in debug.debug_flags:
 
670
 
            mutter('              %d bytes in readv request', len(readv_bytes))
 
671
 
        self._last_verb = args[0]
 
673
 
    def call_with_body_stream(self, args, stream):
 
674
 
        # Protocols v1 and v2 don't support body streams.  So it's safe to
 
675
 
        # assume that a v1/v2 server doesn't support whatever method we're
 
676
 
        # trying to call with a body stream.
 
677
 
        self._request.finished_writing()
 
678
 
        self._request.finished_reading()
 
679
 
        raise errors.UnknownSmartMethod(args[0])
 
681
 
    def cancel_read_body(self):
 
682
 
        """After expecting a body, a response code may indicate one otherwise.
 
684
 
        This method lets the domain client inform the protocol that no body
 
685
 
        will be transmitted. This is a terminal method: after calling it the
 
686
 
        protocol is not able to be used further.
 
688
 
        self._request.finished_reading()
 
690
 
    def _read_response_tuple(self):
 
691
 
        result = self._recv_tuple()
 
692
 
        if 'hpss' in debug.debug_flags:
 
693
 
            if self._request_start_time is not None:
 
694
 
                mutter('   result:   %6.3fs  %s',
 
695
 
                       osutils.timer_func() - self._request_start_time,
 
697
 
                self._request_start_time = None
 
699
 
                mutter('   result:   %s', repr(result)[1:-1])
 
702
 
    def read_response_tuple(self, expect_body=False):
 
703
 
        """Read a response tuple from the wire.
 
705
 
        This should only be called once.
 
707
 
        result = self._read_response_tuple()
 
708
 
        self._response_is_unknown_method(result)
 
709
 
        self._raise_args_if_error(result)
 
711
 
            self._request.finished_reading()
 
714
 
    def _raise_args_if_error(self, result_tuple):
 
715
 
        # Later protocol versions have an explicit flag in the protocol to say
 
716
 
        # if an error response is "failed" or not.  In version 1 we don't have
 
717
 
        # that luxury.  So here is a complete list of errors that can be
 
718
 
        # returned in response to existing version 1 smart requests.  Responses
 
719
 
        # starting with these codes are always "failed" responses.
 
726
 
            'UnicodeEncodeError',
 
727
 
            'UnicodeDecodeError',
 
733
 
            'UnlockableTransport',
 
739
 
        if result_tuple[0] in v1_error_codes:
 
740
 
            self._request.finished_reading()
 
741
 
            raise errors.ErrorFromSmartServer(result_tuple)
 
743
 
    def _response_is_unknown_method(self, result_tuple):
 
744
 
        """Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
 
745
 
        method' response to the request.
 
747
 
        :param response: The response from a smart client call_expecting_body
 
749
 
        :param verb: The verb used in that call.
 
750
 
        :raises: UnexpectedSmartServerResponse
 
752
 
        if (result_tuple == ('error', "Generic bzr smart protocol error: "
 
753
 
                "bad request '%s'" % self._last_verb) or
 
754
 
              result_tuple == ('error', "Generic bzr smart protocol error: "
 
755
 
                "bad request u'%s'" % self._last_verb)):
 
756
 
            # The response will have no body, so we've finished reading.
 
757
 
            self._request.finished_reading()
 
758
 
            raise errors.UnknownSmartMethod(self._last_verb)
 
760
 
    def read_body_bytes(self, count=-1):
 
761
 
        """Read bytes from the body, decoding into a byte stream.
 
763
 
        We read all bytes at once to ensure we've checked the trailer for
 
764
 
        errors, and then feed the buffer back as read_body_bytes is called.
 
766
 
        if self._body_buffer is not None:
 
767
 
            return self._body_buffer.read(count)
 
768
 
        _body_decoder = LengthPrefixedBodyDecoder()
 
770
 
        while not _body_decoder.finished_reading:
 
771
 
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
773
 
                # end of file encountered reading from server
 
774
 
                raise errors.ConnectionReset(
 
775
 
                    "Connection lost while reading response body.")
 
776
 
            _body_decoder.accept_bytes(bytes)
 
777
 
        self._request.finished_reading()
 
778
 
        self._body_buffer = StringIO(_body_decoder.read_pending_data())
 
779
 
        # XXX: TODO check the trailer result.
 
780
 
        if 'hpss' in debug.debug_flags:
 
781
 
            mutter('              %d body bytes read',
 
782
 
                   len(self._body_buffer.getvalue()))
 
783
 
        return self._body_buffer.read(count)
 
785
 
    def _recv_tuple(self):
 
786
 
        """Receive a tuple from the medium request."""
 
787
 
        return _decode_tuple(self._request.read_line())
 
789
 
    def query_version(self):
 
790
 
        """Return protocol version number of the server."""
 
792
 
        resp = self.read_response_tuple()
 
793
 
        if resp == ('ok', '1'):
 
795
 
        elif resp == ('ok', '2'):
 
798
 
            raise errors.SmartProtocolError("bad response %r" % (resp,))
 
800
 
    def _write_args(self, args):
 
801
 
        self._write_protocol_version()
 
802
 
        bytes = _encode_tuple(args)
 
803
 
        self._request.accept_bytes(bytes)
 
805
 
    def _write_protocol_version(self):
 
806
 
        """Write any prefixes this protocol requires.
 
808
 
        Version one doesn't send protocol versions.
 
812
 
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
 
813
 
    """Version two of the client side of the smart protocol.
 
815
 
    This prefixes the request with the value of REQUEST_VERSION_TWO.
 
818
 
    response_marker = RESPONSE_VERSION_TWO
 
819
 
    request_marker = REQUEST_VERSION_TWO
 
821
 
    def read_response_tuple(self, expect_body=False):
 
822
 
        """Read a response tuple from the wire.
 
824
 
        This should only be called once.
 
826
 
        version = self._request.read_line()
 
827
 
        if version != self.response_marker:
 
828
 
            self._request.finished_reading()
 
829
 
            raise errors.UnexpectedProtocolVersionMarker(version)
 
830
 
        response_status = self._request.read_line()
 
831
 
        result = SmartClientRequestProtocolOne._read_response_tuple(self)
 
832
 
        self._response_is_unknown_method(result)
 
833
 
        if response_status == 'success\n':
 
834
 
            self.response_status = True
 
836
 
                self._request.finished_reading()
 
838
 
        elif response_status == 'failed\n':
 
839
 
            self.response_status = False
 
840
 
            self._request.finished_reading()
 
841
 
            raise errors.ErrorFromSmartServer(result)
 
843
 
            raise errors.SmartProtocolError(
 
844
 
                'bad protocol status %r' % response_status)
 
846
 
    def _write_protocol_version(self):
 
847
 
        """Write any prefixes this protocol requires.
 
849
 
        Version two sends the value of REQUEST_VERSION_TWO.
 
851
 
        self._request.accept_bytes(self.request_marker)
 
853
 
    def read_streamed_body(self):
 
854
 
        """Read bytes from the body, decoding into a byte stream.
 
856
 
        # Read no more than 64k at a time so that we don't risk error 10055 (no
 
857
 
        # buffer space available) on Windows.
 
858
 
        _body_decoder = ChunkedBodyDecoder()
 
859
 
        while not _body_decoder.finished_reading:
 
860
 
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
862
 
                # end of file encountered reading from server
 
863
 
                raise errors.ConnectionReset(
 
864
 
                    "Connection lost while reading streamed body.")
 
865
 
            _body_decoder.accept_bytes(bytes)
 
866
 
            for body_bytes in iter(_body_decoder.read_next_chunk, None):
 
867
 
                if 'hpss' in debug.debug_flags and type(body_bytes) is str:
 
868
 
                    mutter('              %d byte chunk read',
 
871
 
        self._request.finished_reading()
 
874
 
def build_server_protocol_three(backing_transport, write_func,
 
875
 
                                root_client_path, jail_root=None):
 
876
 
    request_handler = request.SmartServerRequestHandler(
 
877
 
        backing_transport, commands=request.request_handlers,
 
878
 
        root_client_path=root_client_path, jail_root=jail_root)
 
879
 
    responder = ProtocolThreeResponder(write_func)
 
880
 
    message_handler = message.ConventionalRequestHandler(request_handler, responder)
 
881
 
    return ProtocolThreeDecoder(message_handler)
 
884
 
class ProtocolThreeDecoder(_StatefulDecoder):
 
886
 
    response_marker = RESPONSE_VERSION_THREE
 
887
 
    request_marker = REQUEST_VERSION_THREE
 
889
 
    def __init__(self, message_handler, expect_version_marker=False):
 
890
 
        _StatefulDecoder.__init__(self)
 
891
 
        self._has_dispatched = False
 
893
 
        if expect_version_marker:
 
894
 
            self.state_accept = self._state_accept_expecting_protocol_version
 
895
 
            # We're expecting at least the protocol version marker + some
 
897
 
            self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
 
899
 
            self.state_accept = self._state_accept_expecting_headers
 
900
 
            self._number_needed_bytes = 4
 
901
 
        self.decoding_failed = False
 
902
 
        self.request_handler = self.message_handler = message_handler
 
904
 
    def accept_bytes(self, bytes):
 
905
 
        self._number_needed_bytes = None
 
907
 
            _StatefulDecoder.accept_bytes(self, bytes)
 
908
 
        except KeyboardInterrupt:
 
910
 
        except errors.SmartMessageHandlerError, exception:
 
911
 
            # We do *not* set self.decoding_failed here.  The message handler
 
912
 
            # has raised an error, but the decoder is still able to parse bytes
 
913
 
            # and determine when this message ends.
 
914
 
            if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
 
915
 
                log_exception_quietly()
 
916
 
            self.message_handler.protocol_error(exception.exc_value)
 
917
 
            # The state machine is ready to continue decoding, but the
 
918
 
            # exception has interrupted the loop that runs the state machine.
 
919
 
            # So we call accept_bytes again to restart it.
 
920
 
            self.accept_bytes('')
 
921
 
        except Exception, exception:
 
922
 
            # The decoder itself has raised an exception.  We cannot continue
 
924
 
            self.decoding_failed = True
 
925
 
            if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
 
926
 
                # This happens during normal operation when the client tries a
 
927
 
                # protocol version the server doesn't understand, so no need to
 
928
 
                # log a traceback every time.
 
929
 
                # Note that this can only happen when
 
930
 
                # expect_version_marker=True, which is only the case on the
 
934
 
                log_exception_quietly()
 
935
 
            self.message_handler.protocol_error(exception)
 
937
 
    def _extract_length_prefixed_bytes(self):
 
938
 
        if self._in_buffer_len < 4:
 
939
 
            # A length prefix by itself is 4 bytes, and we don't even have that
 
941
 
            raise _NeedMoreBytes(4)
 
942
 
        (length,) = struct.unpack('!L', self._get_in_bytes(4))
 
943
 
        end_of_bytes = 4 + length
 
944
 
        if self._in_buffer_len < end_of_bytes:
 
945
 
            # We haven't yet read as many bytes as the length-prefix says there
 
947
 
            raise _NeedMoreBytes(end_of_bytes)
 
948
 
        # Extract the bytes from the buffer.
 
949
 
        in_buf = self._get_in_buffer()
 
950
 
        bytes = in_buf[4:end_of_bytes]
 
951
 
        self._set_in_buffer(in_buf[end_of_bytes:])
 
954
 
    def _extract_prefixed_bencoded_data(self):
 
955
 
        prefixed_bytes = self._extract_length_prefixed_bytes()
 
957
 
            decoded = bdecode_as_tuple(prefixed_bytes)
 
959
 
            raise errors.SmartProtocolError(
 
960
 
                'Bytes %r not bencoded' % (prefixed_bytes,))
 
963
 
    def _extract_single_byte(self):
 
964
 
        if self._in_buffer_len == 0:
 
965
 
            # The buffer is empty
 
966
 
            raise _NeedMoreBytes(1)
 
967
 
        in_buf = self._get_in_buffer()
 
969
 
        self._set_in_buffer(in_buf[1:])
 
972
 
    def _state_accept_expecting_protocol_version(self):
 
973
 
        needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
 
974
 
        in_buf = self._get_in_buffer()
 
976
 
            # We don't have enough bytes to check if the protocol version
 
977
 
            # marker is right.  But we can check if it is already wrong by
 
978
 
            # checking that the start of MESSAGE_VERSION_THREE matches what
 
980
 
            # [In fact, if the remote end isn't bzr we might never receive
 
981
 
            # len(MESSAGE_VERSION_THREE) bytes.  So if the bytes we have so far
 
982
 
            # are wrong then we should just raise immediately rather than
 
984
 
            if not MESSAGE_VERSION_THREE.startswith(in_buf):
 
985
 
                # We have enough bytes to know the protocol version is wrong
 
986
 
                raise errors.UnexpectedProtocolVersionMarker(in_buf)
 
987
 
            raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
 
988
 
        if not in_buf.startswith(MESSAGE_VERSION_THREE):
 
989
 
            raise errors.UnexpectedProtocolVersionMarker(in_buf)
 
990
 
        self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
 
991
 
        self.state_accept = self._state_accept_expecting_headers
 
993
 
    def _state_accept_expecting_headers(self):
 
994
 
        decoded = self._extract_prefixed_bencoded_data()
 
995
 
        if type(decoded) is not dict:
 
996
 
            raise errors.SmartProtocolError(
 
997
 
                'Header object %r is not a dict' % (decoded,))
 
998
 
        self.state_accept = self._state_accept_expecting_message_part
 
1000
 
            self.message_handler.headers_received(decoded)
 
1002
 
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1004
 
    def _state_accept_expecting_message_part(self):
 
1005
 
        message_part_kind = self._extract_single_byte()
 
1006
 
        if message_part_kind == 'o':
 
1007
 
            self.state_accept = self._state_accept_expecting_one_byte
 
1008
 
        elif message_part_kind == 's':
 
1009
 
            self.state_accept = self._state_accept_expecting_structure
 
1010
 
        elif message_part_kind == 'b':
 
1011
 
            self.state_accept = self._state_accept_expecting_bytes
 
1012
 
        elif message_part_kind == 'e':
 
1015
 
            raise errors.SmartProtocolError(
 
1016
 
                'Bad message kind byte: %r' % (message_part_kind,))
 
1018
 
    def _state_accept_expecting_one_byte(self):
 
1019
 
        byte = self._extract_single_byte()
 
1020
 
        self.state_accept = self._state_accept_expecting_message_part
 
1022
 
            self.message_handler.byte_part_received(byte)
 
1024
 
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1026
 
    def _state_accept_expecting_bytes(self):
 
1027
 
        # XXX: this should not buffer whole message part, but instead deliver
 
1028
 
        # the bytes as they arrive.
 
1029
 
        prefixed_bytes = self._extract_length_prefixed_bytes()
 
1030
 
        self.state_accept = self._state_accept_expecting_message_part
 
1032
 
            self.message_handler.bytes_part_received(prefixed_bytes)
 
1034
 
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1036
 
    def _state_accept_expecting_structure(self):
 
1037
 
        structure = self._extract_prefixed_bencoded_data()
 
1038
 
        self.state_accept = self._state_accept_expecting_message_part
 
1040
 
            self.message_handler.structure_part_received(structure)
 
1042
 
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1045
 
        self.unused_data = self._get_in_buffer()
 
1046
 
        self._set_in_buffer(None)
 
1047
 
        self.state_accept = self._state_accept_reading_unused
 
1049
 
            self.message_handler.end_received()
 
1051
 
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1053
 
    def _state_accept_reading_unused(self):
 
1054
 
        self.unused_data += self._get_in_buffer()
 
1055
 
        self._set_in_buffer(None)
 
1057
 
    def next_read_size(self):
 
1058
 
        if self.state_accept == self._state_accept_reading_unused:
 
1060
 
        elif self.decoding_failed:
 
1061
 
            # An exception occured while processing this message, probably from
 
1062
 
            # self.message_handler.  We're not sure that this state machine is
 
1063
 
            # in a consistent state, so just signal that we're done (i.e. give
 
1067
 
            if self._number_needed_bytes is not None:
 
1068
 
                return self._number_needed_bytes - self._in_buffer_len
 
1070
 
                raise AssertionError("don't know how many bytes are expected!")
 
1073
 
class _ProtocolThreeEncoder(object):
 
1075
 
    response_marker = request_marker = MESSAGE_VERSION_THREE
 
1076
 
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
 
1078
 
    def __init__(self, write_func):
 
1081
 
        self._real_write_func = write_func
 
1083
 
    def _write_func(self, bytes):
 
1084
 
        # TODO: It is probably more appropriate to use sum(map(len, _buf))
 
1085
 
        #       for total number of bytes to write, rather than buffer based on
 
1086
 
        #       the number of write() calls
 
1087
 
        # TODO: Another possibility would be to turn this into an async model.
 
1088
 
        #       Where we let another thread know that we have some bytes if
 
1089
 
        #       they want it, but we don't actually block for it
 
1090
 
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
 
1091
 
        #       we might just push out smaller bits at a time?
 
1092
 
        self._buf.append(bytes)
 
1093
 
        self._buf_len += len(bytes)
 
1094
 
        if self._buf_len > self.BUFFER_SIZE:
 
1099
 
            self._real_write_func(''.join(self._buf))
 
1103
 
    def _serialise_offsets(self, offsets):
 
1104
 
        """Serialise a readv offset list."""
 
1106
 
        for start, length in offsets:
 
1107
 
            txt.append('%d,%d' % (start, length))
 
1108
 
        return '\n'.join(txt)
 
1110
 
    def _write_protocol_version(self):
 
1111
 
        self._write_func(MESSAGE_VERSION_THREE)
 
1113
 
    def _write_prefixed_bencode(self, structure):
 
1114
 
        bytes = bencode(structure)
 
1115
 
        self._write_func(struct.pack('!L', len(bytes)))
 
1116
 
        self._write_func(bytes)
 
1118
 
    def _write_headers(self, headers):
 
1119
 
        self._write_prefixed_bencode(headers)
 
1121
 
    def _write_structure(self, args):
 
1122
 
        self._write_func('s')
 
1125
 
            if type(arg) is unicode:
 
1126
 
                utf8_args.append(arg.encode('utf8'))
 
1128
 
                utf8_args.append(arg)
 
1129
 
        self._write_prefixed_bencode(utf8_args)
 
1131
 
    def _write_end(self):
 
1132
 
        self._write_func('e')
 
1135
 
    def _write_prefixed_body(self, bytes):
 
1136
 
        self._write_func('b')
 
1137
 
        self._write_func(struct.pack('!L', len(bytes)))
 
1138
 
        self._write_func(bytes)
 
1140
 
    def _write_chunked_body_start(self):
 
1141
 
        self._write_func('oC')
 
1143
 
    def _write_error_status(self):
 
1144
 
        self._write_func('oE')
 
1146
 
    def _write_success_status(self):
 
1147
 
        self._write_func('oS')
 
1150
 
class ProtocolThreeResponder(_ProtocolThreeEncoder):
 
1152
 
    def __init__(self, write_func):
 
1153
 
        _ProtocolThreeEncoder.__init__(self, write_func)
 
1154
 
        self.response_sent = False
 
1155
 
        self._headers = {'Software version': bzrlib.__version__}
 
1156
 
        if 'hpss' in debug.debug_flags:
 
1157
 
            self._thread_id = thread.get_ident()
 
1158
 
            self._response_start_time = None
 
1160
 
    def _trace(self, action, message, extra_bytes=None, include_time=False):
 
1161
 
        if self._response_start_time is None:
 
1162
 
            self._response_start_time = osutils.timer_func()
 
1164
 
            t = '%5.3fs ' % (time.clock() - self._response_start_time)
 
1167
 
        if extra_bytes is None:
 
1170
 
            extra = ' ' + repr(extra_bytes[:40])
 
1172
 
                extra = extra[:29] + extra[-1] + '...'
 
1173
 
        mutter('%12s: [%s] %s%s%s'
 
1174
 
               % (action, self._thread_id, t, message, extra))
 
1176
 
    def send_error(self, exception):
 
1177
 
        if self.response_sent:
 
1178
 
            raise AssertionError(
 
1179
 
                "send_error(%s) called, but response already sent."
 
1181
 
        if isinstance(exception, errors.UnknownSmartMethod):
 
1182
 
            failure = request.FailedSmartServerResponse(
 
1183
 
                ('UnknownMethod', exception.verb))
 
1184
 
            self.send_response(failure)
 
1186
 
        if 'hpss' in debug.debug_flags:
 
1187
 
            self._trace('error', str(exception))
 
1188
 
        self.response_sent = True
 
1189
 
        self._write_protocol_version()
 
1190
 
        self._write_headers(self._headers)
 
1191
 
        self._write_error_status()
 
1192
 
        self._write_structure(('error', str(exception)))
 
1195
 
    def send_response(self, response):
 
1196
 
        if self.response_sent:
 
1197
 
            raise AssertionError(
 
1198
 
                "send_response(%r) called, but response already sent."
 
1200
 
        self.response_sent = True
 
1201
 
        self._write_protocol_version()
 
1202
 
        self._write_headers(self._headers)
 
1203
 
        if response.is_successful():
 
1204
 
            self._write_success_status()
 
1206
 
            self._write_error_status()
 
1207
 
        if 'hpss' in debug.debug_flags:
 
1208
 
            self._trace('response', repr(response.args))
 
1209
 
        self._write_structure(response.args)
 
1210
 
        if response.body is not None:
 
1211
 
            self._write_prefixed_body(response.body)
 
1212
 
            if 'hpss' in debug.debug_flags:
 
1213
 
                self._trace('body', '%d bytes' % (len(response.body),),
 
1214
 
                            response.body, include_time=True)
 
1215
 
        elif response.body_stream is not None:
 
1216
 
            count = num_bytes = 0
 
1218
 
            for exc_info, chunk in _iter_with_errors(response.body_stream):
 
1220
 
                if exc_info is not None:
 
1221
 
                    self._write_error_status()
 
1222
 
                    error_struct = request._translate_error(exc_info[1])
 
1223
 
                    self._write_structure(error_struct)
 
1226
 
                    if isinstance(chunk, request.FailedSmartServerResponse):
 
1227
 
                        self._write_error_status()
 
1228
 
                        self._write_structure(chunk.args)
 
1230
 
                    num_bytes += len(chunk)
 
1231
 
                    if first_chunk is None:
 
1233
 
                    self._write_prefixed_body(chunk)
 
1234
 
                    if 'hpssdetail' in debug.debug_flags:
 
1235
 
                        # Not worth timing separately, as _write_func is
 
1237
 
                        self._trace('body chunk',
 
1238
 
                                    '%d bytes' % (len(chunk),),
 
1239
 
                                    chunk, suppress_time=True)
 
1240
 
            if 'hpss' in debug.debug_flags:
 
1241
 
                self._trace('body stream',
 
1242
 
                            '%d bytes %d chunks' % (num_bytes, count),
 
1245
 
        if 'hpss' in debug.debug_flags:
 
1246
 
            self._trace('response end', '', include_time=True)
 
1249
 
def _iter_with_errors(iterable):
 
1250
 
    """Handle errors from iterable.next().
 
1254
 
        for exc_info, value in _iter_with_errors(iterable):
 
1257
 
    This is a safer alternative to::
 
1260
 
            for value in iterable:
 
1265
 
    Because the latter will catch errors from the for-loop body, not just
 
1268
 
    If an error occurs, exc_info will be a exc_info tuple, and the generator
 
1269
 
    will terminate.  Otherwise exc_info will be None, and value will be the
 
1270
 
    value from iterable.next().  Note that KeyboardInterrupt and SystemExit
 
1271
 
    will not be itercepted.
 
1273
 
    iterator = iter(iterable)
 
1276
 
            yield None, iterator.next()
 
1277
 
        except StopIteration:
 
1279
 
        except (KeyboardInterrupt, SystemExit):
 
1282
 
            mutter('_iter_with_errors caught error')
 
1283
 
            log_exception_quietly()
 
1284
 
            yield sys.exc_info(), None
 
1288
 
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
 
1290
 
    def __init__(self, medium_request):
 
1291
 
        _ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
 
1292
 
        self._medium_request = medium_request
 
1295
 
    def set_headers(self, headers):
 
1296
 
        self._headers = headers.copy()
 
1298
 
    def call(self, *args):
 
1299
 
        if 'hpss' in debug.debug_flags:
 
1300
 
            mutter('hpss call:   %s', repr(args)[1:-1])
 
1301
 
            base = getattr(self._medium_request._medium, 'base', None)
 
1302
 
            if base is not None:
 
1303
 
                mutter('             (to %s)', base)
 
1304
 
            self._request_start_time = osutils.timer_func()
 
1305
 
        self._write_protocol_version()
 
1306
 
        self._write_headers(self._headers)
 
1307
 
        self._write_structure(args)
 
1309
 
        self._medium_request.finished_writing()
 
1311
 
    def call_with_body_bytes(self, args, body):
 
1312
 
        """Make a remote call of args with body bytes 'body'.
 
1314
 
        After calling this, call read_response_tuple to find the result out.
 
1316
 
        if 'hpss' in debug.debug_flags:
 
1317
 
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
 
1318
 
            path = getattr(self._medium_request._medium, '_path', None)
 
1319
 
            if path is not None:
 
1320
 
                mutter('                  (to %s)', path)
 
1321
 
            mutter('              %d bytes', len(body))
 
1322
 
            self._request_start_time = osutils.timer_func()
 
1323
 
        self._write_protocol_version()
 
1324
 
        self._write_headers(self._headers)
 
1325
 
        self._write_structure(args)
 
1326
 
        self._write_prefixed_body(body)
 
1328
 
        self._medium_request.finished_writing()
 
1330
 
    def call_with_body_readv_array(self, args, body):
 
1331
 
        """Make a remote call with a readv array.
 
1333
 
        The body is encoded with one line per readv offset pair. The numbers in
 
1334
 
        each pair are separated by a comma, and no trailing \n is emitted.
 
1336
 
        if 'hpss' in debug.debug_flags:
 
1337
 
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
 
1338
 
            path = getattr(self._medium_request._medium, '_path', None)
 
1339
 
            if path is not None:
 
1340
 
                mutter('                  (to %s)', path)
 
1341
 
            self._request_start_time = osutils.timer_func()
 
1342
 
        self._write_protocol_version()
 
1343
 
        self._write_headers(self._headers)
 
1344
 
        self._write_structure(args)
 
1345
 
        readv_bytes = self._serialise_offsets(body)
 
1346
 
        if 'hpss' in debug.debug_flags:
 
1347
 
            mutter('              %d bytes in readv request', len(readv_bytes))
 
1348
 
        self._write_prefixed_body(readv_bytes)
 
1350
 
        self._medium_request.finished_writing()
 
1352
 
    def call_with_body_stream(self, args, stream):
 
1353
 
        if 'hpss' in debug.debug_flags:
 
1354
 
            mutter('hpss call w/body stream: %r', args)
 
1355
 
            path = getattr(self._medium_request._medium, '_path', None)
 
1356
 
            if path is not None:
 
1357
 
                mutter('                  (to %s)', path)
 
1358
 
            self._request_start_time = osutils.timer_func()
 
1359
 
        self._write_protocol_version()
 
1360
 
        self._write_headers(self._headers)
 
1361
 
        self._write_structure(args)
 
1362
 
        # TODO: notice if the server has sent an early error reply before we
 
1363
 
        #       have finished sending the stream.  We would notice at the end
 
1364
 
        #       anyway, but if the medium can deliver it early then it's good
 
1365
 
        #       to short-circuit the whole request...
 
1366
 
        for exc_info, part in _iter_with_errors(stream):
 
1367
 
            if exc_info is not None:
 
1368
 
                # Iterating the stream failed.  Cleanly abort the request.
 
1369
 
                self._write_error_status()
 
1370
 
                # Currently the client unconditionally sends ('error',) as the
 
1372
 
                self._write_structure(('error',))
 
1374
 
                self._medium_request.finished_writing()
 
1375
 
                raise exc_info[0], exc_info[1], exc_info[2]
 
1377
 
                self._write_prefixed_body(part)
 
1380
 
        self._medium_request.finished_writing()