/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to breezy/smart/protocol.py

  • Committer: Jelmer Vernooij
  • Date: 2017-06-10 16:40:42 UTC
  • mfrom: (6653.6.7 rename-controldir)
  • mto: This revision was merged to the branch mainline in revision 6690.
  • Revision ID: jelmer@jelmer.uk-20170610164042-zrxqgy2htyduvke2
MergeĀ rename-controldirĀ branch.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006-2010 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
"""Wire-level encoding and decoding of requests and responses for the smart
 
18
client and server.
 
19
"""
 
20
 
 
21
from __future__ import absolute_import
 
22
 
 
23
import collections
 
24
import struct
 
25
import sys
 
26
try:
 
27
    import _thread
 
28
except ImportError:
 
29
    import thread as _thread
 
30
import time
 
31
 
 
32
import breezy
 
33
from .. import (
 
34
    debug,
 
35
    errors,
 
36
    osutils,
 
37
    )
 
38
from ..sixish import (
 
39
    BytesIO,
 
40
    reraise,
 
41
)
 
42
from . import message, request
 
43
from ..trace import log_exception_quietly, mutter
 
44
from ..bencode import bdecode_as_tuple, bencode
 
45
 
 
46
 
 
47
# Protocol version strings.  These are sent as prefixes of bzr requests and
 
48
# responses to identify the protocol version being used. (There are no version
 
49
# one strings because that version doesn't send any).
 
50
REQUEST_VERSION_TWO = 'bzr request 2\n'
 
51
RESPONSE_VERSION_TWO = 'bzr response 2\n'
 
52
 
 
53
MESSAGE_VERSION_THREE = 'bzr message 3 (bzr 1.6)\n'
 
54
RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
 
55
 
 
56
 
 
57
def _recv_tuple(from_file):
 
58
    req_line = from_file.readline()
 
59
    return _decode_tuple(req_line)
 
60
 
 
61
 
 
62
def _decode_tuple(req_line):
 
63
    if req_line is None or req_line == '':
 
64
        return None
 
65
    if req_line[-1] != '\n':
 
66
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
 
67
    return tuple(req_line[:-1].split('\x01'))
 
68
 
 
69
 
 
70
def _encode_tuple(args):
 
71
    """Encode the tuple args to a bytestream."""
 
72
    joined = '\x01'.join(args) + '\n'
 
73
    if isinstance(joined, unicode):
 
74
        # XXX: We should fix things so this never happens!  -AJB, 20100304
 
75
        mutter('response args contain unicode, should be only bytes: %r',
 
76
               joined)
 
77
        joined = joined.encode('ascii')
 
78
    return joined
 
79
 
 
80
 
 
81
class Requester(object):
 
82
    """Abstract base class for an object that can issue requests on a smart
 
83
    medium.
 
84
    """
 
85
 
 
86
    def call(self, *args):
 
87
        """Make a remote call.
 
88
 
 
89
        :param args: the arguments of this call.
 
90
        """
 
91
        raise NotImplementedError(self.call)
 
92
 
 
93
    def call_with_body_bytes(self, args, body):
 
94
        """Make a remote call with a body.
 
95
 
 
96
        :param args: the arguments of this call.
 
97
        :type body: str
 
98
        :param body: the body to send with the request.
 
99
        """
 
100
        raise NotImplementedError(self.call_with_body_bytes)
 
101
 
 
102
    def call_with_body_readv_array(self, args, body):
 
103
        """Make a remote call with a readv array.
 
104
 
 
105
        :param args: the arguments of this call.
 
106
        :type body: iterable of (start, length) tuples.
 
107
        :param body: the readv ranges to send with this request.
 
108
        """
 
109
        raise NotImplementedError(self.call_with_body_readv_array)
 
110
 
 
111
    def set_headers(self, headers):
 
112
        raise NotImplementedError(self.set_headers)
 
113
 
 
114
 
 
115
class SmartProtocolBase(object):
 
116
    """Methods common to client and server"""
 
117
 
 
118
    # TODO: this only actually accomodates a single block; possibly should
 
119
    # support multiple chunks?
 
120
    def _encode_bulk_data(self, body):
 
121
        """Encode body as a bulk data chunk."""
 
122
        return ''.join(('%d\n' % len(body), body, 'done\n'))
 
123
 
 
124
    def _serialise_offsets(self, offsets):
 
125
        """Serialise a readv offset list."""
 
126
        txt = []
 
127
        for start, length in offsets:
 
128
            txt.append('%d,%d' % (start, length))
 
129
        return '\n'.join(txt)
 
130
 
 
131
 
 
132
class SmartServerRequestProtocolOne(SmartProtocolBase):
 
133
    """Server-side encoding and decoding logic for smart version 1."""
 
134
 
 
135
    def __init__(self, backing_transport, write_func, root_client_path='/',
 
136
            jail_root=None):
 
137
        self._backing_transport = backing_transport
 
138
        self._root_client_path = root_client_path
 
139
        self._jail_root = jail_root
 
140
        self.unused_data = ''
 
141
        self._finished = False
 
142
        self.in_buffer = ''
 
143
        self._has_dispatched = False
 
144
        self.request = None
 
145
        self._body_decoder = None
 
146
        self._write_func = write_func
 
147
 
 
148
    def accept_bytes(self, bytes):
 
149
        """Take bytes, and advance the internal state machine appropriately.
 
150
 
 
151
        :param bytes: must be a byte string
 
152
        """
 
153
        if not isinstance(bytes, str):
 
154
            raise ValueError(bytes)
 
155
        self.in_buffer += bytes
 
156
        if not self._has_dispatched:
 
157
            if '\n' not in self.in_buffer:
 
158
                # no command line yet
 
159
                return
 
160
            self._has_dispatched = True
 
161
            try:
 
162
                first_line, self.in_buffer = self.in_buffer.split('\n', 1)
 
163
                first_line += '\n'
 
164
                req_args = _decode_tuple(first_line)
 
165
                self.request = request.SmartServerRequestHandler(
 
166
                    self._backing_transport, commands=request.request_handlers,
 
167
                    root_client_path=self._root_client_path,
 
168
                    jail_root=self._jail_root)
 
169
                self.request.args_received(req_args)
 
170
                if self.request.finished_reading:
 
171
                    # trivial request
 
172
                    self.unused_data = self.in_buffer
 
173
                    self.in_buffer = ''
 
174
                    self._send_response(self.request.response)
 
175
            except KeyboardInterrupt:
 
176
                raise
 
177
            except errors.UnknownSmartMethod as err:
 
178
                protocol_error = errors.SmartProtocolError(
 
179
                    "bad request %r" % (err.verb,))
 
180
                failure = request.FailedSmartServerResponse(
 
181
                    ('error', str(protocol_error)))
 
182
                self._send_response(failure)
 
183
                return
 
184
            except Exception as exception:
 
185
                # everything else: pass to client, flush, and quit
 
186
                log_exception_quietly()
 
187
                self._send_response(request.FailedSmartServerResponse(
 
188
                    ('error', str(exception))))
 
189
                return
 
190
 
 
191
        if self._has_dispatched:
 
192
            if self._finished:
 
193
                # nothing to do.XXX: this routine should be a single state
 
194
                # machine too.
 
195
                self.unused_data += self.in_buffer
 
196
                self.in_buffer = ''
 
197
                return
 
198
            if self._body_decoder is None:
 
199
                self._body_decoder = LengthPrefixedBodyDecoder()
 
200
            self._body_decoder.accept_bytes(self.in_buffer)
 
201
            self.in_buffer = self._body_decoder.unused_data
 
202
            body_data = self._body_decoder.read_pending_data()
 
203
            self.request.accept_body(body_data)
 
204
            if self._body_decoder.finished_reading:
 
205
                self.request.end_of_body()
 
206
                if not self.request.finished_reading:
 
207
                    raise AssertionError("no more body, request not finished")
 
208
            if self.request.response is not None:
 
209
                self._send_response(self.request.response)
 
210
                self.unused_data = self.in_buffer
 
211
                self.in_buffer = ''
 
212
            else:
 
213
                if self.request.finished_reading:
 
214
                    raise AssertionError(
 
215
                        "no response and we have finished reading.")
 
216
 
 
217
    def _send_response(self, response):
 
218
        """Send a smart server response down the output stream."""
 
219
        if self._finished:
 
220
            raise AssertionError('response already sent')
 
221
        args = response.args
 
222
        body = response.body
 
223
        self._finished = True
 
224
        self._write_protocol_version()
 
225
        self._write_success_or_failure_prefix(response)
 
226
        self._write_func(_encode_tuple(args))
 
227
        if body is not None:
 
228
            if not isinstance(body, str):
 
229
                raise ValueError(body)
 
230
            bytes = self._encode_bulk_data(body)
 
231
            self._write_func(bytes)
 
232
 
 
233
    def _write_protocol_version(self):
 
234
        """Write any prefixes this protocol requires.
 
235
 
 
236
        Version one doesn't send protocol versions.
 
237
        """
 
238
 
 
239
    def _write_success_or_failure_prefix(self, response):
 
240
        """Write the protocol specific success/failure prefix.
 
241
 
 
242
        For SmartServerRequestProtocolOne this is omitted but we
 
243
        call is_successful to ensure that the response is valid.
 
244
        """
 
245
        response.is_successful()
 
246
 
 
247
    def next_read_size(self):
 
248
        if self._finished:
 
249
            return 0
 
250
        if self._body_decoder is None:
 
251
            return 1
 
252
        else:
 
253
            return self._body_decoder.next_read_size()
 
254
 
 
255
 
 
256
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
 
257
    r"""Version two of the server side of the smart protocol.
 
258
 
 
259
    This prefixes responses with the value of RESPONSE_VERSION_TWO.
 
260
    """
 
261
 
 
262
    response_marker = RESPONSE_VERSION_TWO
 
263
    request_marker = REQUEST_VERSION_TWO
 
264
 
 
265
    def _write_success_or_failure_prefix(self, response):
 
266
        """Write the protocol specific success/failure prefix."""
 
267
        if response.is_successful():
 
268
            self._write_func('success\n')
 
269
        else:
 
270
            self._write_func('failed\n')
 
271
 
 
272
    def _write_protocol_version(self):
 
273
        r"""Write any prefixes this protocol requires.
 
274
 
 
275
        Version two sends the value of RESPONSE_VERSION_TWO.
 
276
        """
 
277
        self._write_func(self.response_marker)
 
278
 
 
279
    def _send_response(self, response):
 
280
        """Send a smart server response down the output stream."""
 
281
        if (self._finished):
 
282
            raise AssertionError('response already sent')
 
283
        self._finished = True
 
284
        self._write_protocol_version()
 
285
        self._write_success_or_failure_prefix(response)
 
286
        self._write_func(_encode_tuple(response.args))
 
287
        if response.body is not None:
 
288
            if not isinstance(response.body, str):
 
289
                raise AssertionError('body must be a str')
 
290
            if not (response.body_stream is None):
 
291
                raise AssertionError(
 
292
                    'body_stream and body cannot both be set')
 
293
            bytes = self._encode_bulk_data(response.body)
 
294
            self._write_func(bytes)
 
295
        elif response.body_stream is not None:
 
296
            _send_stream(response.body_stream, self._write_func)
 
297
 
 
298
 
 
299
def _send_stream(stream, write_func):
 
300
    write_func('chunked\n')
 
301
    _send_chunks(stream, write_func)
 
302
    write_func('END\n')
 
303
 
 
304
 
 
305
def _send_chunks(stream, write_func):
 
306
    for chunk in stream:
 
307
        if isinstance(chunk, str):
 
308
            bytes = "%x\n%s" % (len(chunk), chunk)
 
309
            write_func(bytes)
 
310
        elif isinstance(chunk, request.FailedSmartServerResponse):
 
311
            write_func('ERR\n')
 
312
            _send_chunks(chunk.args, write_func)
 
313
            return
 
314
        else:
 
315
            raise errors.BzrError(
 
316
                'Chunks must be str or FailedSmartServerResponse, got %r'
 
317
                % chunk)
 
318
 
 
319
 
 
320
class _NeedMoreBytes(Exception):
 
321
    """Raise this inside a _StatefulDecoder to stop decoding until more bytes
 
322
    have been received.
 
323
    """
 
324
 
 
325
    def __init__(self, count=None):
 
326
        """Constructor.
 
327
 
 
328
        :param count: the total number of bytes needed by the current state.
 
329
            May be None if the number of bytes needed is unknown.
 
330
        """
 
331
        self.count = count
 
332
 
 
333
 
 
334
class _StatefulDecoder(object):
 
335
    """Base class for writing state machines to decode byte streams.
 
336
 
 
337
    Subclasses should provide a self.state_accept attribute that accepts bytes
 
338
    and, if appropriate, updates self.state_accept to a different function.
 
339
    accept_bytes will call state_accept as often as necessary to make sure the
 
340
    state machine has progressed as far as possible before it returns.
 
341
 
 
342
    See ProtocolThreeDecoder for an example subclass.
 
343
    """
 
344
 
 
345
    def __init__(self):
 
346
        self.finished_reading = False
 
347
        self._in_buffer_list = []
 
348
        self._in_buffer_len = 0
 
349
        self.unused_data = ''
 
350
        self.bytes_left = None
 
351
        self._number_needed_bytes = None
 
352
 
 
353
    def _get_in_buffer(self):
 
354
        if len(self._in_buffer_list) == 1:
 
355
            return self._in_buffer_list[0]
 
356
        in_buffer = ''.join(self._in_buffer_list)
 
357
        if len(in_buffer) != self._in_buffer_len:
 
358
            raise AssertionError(
 
359
                "Length of buffer did not match expected value: %s != %s"
 
360
                % self._in_buffer_len, len(in_buffer))
 
361
        self._in_buffer_list = [in_buffer]
 
362
        return in_buffer
 
363
 
 
364
    def _get_in_bytes(self, count):
 
365
        """Grab X bytes from the input_buffer.
 
366
 
 
367
        Callers should have already checked that self._in_buffer_len is >
 
368
        count. Note, this does not consume the bytes from the buffer. The
 
369
        caller will still need to call _get_in_buffer() and then
 
370
        _set_in_buffer() if they actually need to consume the bytes.
 
371
        """
 
372
        # check if we can yield the bytes from just the first entry in our list
 
373
        if len(self._in_buffer_list) == 0:
 
374
            raise AssertionError('Callers must be sure we have buffered bytes'
 
375
                ' before calling _get_in_bytes')
 
376
        if len(self._in_buffer_list[0]) > count:
 
377
            return self._in_buffer_list[0][:count]
 
378
        # We can't yield it from the first buffer, so collapse all buffers, and
 
379
        # yield it from that
 
380
        in_buf = self._get_in_buffer()
 
381
        return in_buf[:count]
 
382
 
 
383
    def _set_in_buffer(self, new_buf):
 
384
        if new_buf is not None:
 
385
            self._in_buffer_list = [new_buf]
 
386
            self._in_buffer_len = len(new_buf)
 
387
        else:
 
388
            self._in_buffer_list = []
 
389
            self._in_buffer_len = 0
 
390
 
 
391
    def accept_bytes(self, bytes):
 
392
        """Decode as much of bytes as possible.
 
393
 
 
394
        If 'bytes' contains too much data it will be appended to
 
395
        self.unused_data.
 
396
 
 
397
        finished_reading will be set when no more data is required.  Further
 
398
        data will be appended to self.unused_data.
 
399
        """
 
400
        # accept_bytes is allowed to change the state
 
401
        self._number_needed_bytes = None
 
402
        # lsprof puts a very large amount of time on this specific call for
 
403
        # large readv arrays
 
404
        self._in_buffer_list.append(bytes)
 
405
        self._in_buffer_len += len(bytes)
 
406
        try:
 
407
            # Run the function for the current state.
 
408
            current_state = self.state_accept
 
409
            self.state_accept()
 
410
            while current_state != self.state_accept:
 
411
                # The current state has changed.  Run the function for the new
 
412
                # current state, so that it can:
 
413
                #   - decode any unconsumed bytes left in a buffer, and
 
414
                #   - signal how many more bytes are expected (via raising
 
415
                #     _NeedMoreBytes).
 
416
                current_state = self.state_accept
 
417
                self.state_accept()
 
418
        except _NeedMoreBytes as e:
 
419
            self._number_needed_bytes = e.count
 
420
 
 
421
 
 
422
class ChunkedBodyDecoder(_StatefulDecoder):
 
423
    """Decoder for chunked body data.
 
424
 
 
425
    This is very similar the HTTP's chunked encoding.  See the description of
 
426
    streamed body data in `doc/developers/network-protocol.txt` for details.
 
427
    """
 
428
 
 
429
    def __init__(self):
 
430
        _StatefulDecoder.__init__(self)
 
431
        self.state_accept = self._state_accept_expecting_header
 
432
        self.chunk_in_progress = None
 
433
        self.chunks = collections.deque()
 
434
        self.error = False
 
435
        self.error_in_progress = None
 
436
 
 
437
    def next_read_size(self):
 
438
        # Note: the shortest possible chunk is 2 bytes: '0\n', and the
 
439
        # end-of-body marker is 4 bytes: 'END\n'.
 
440
        if self.state_accept == self._state_accept_reading_chunk:
 
441
            # We're expecting more chunk content.  So we're expecting at least
 
442
            # the rest of this chunk plus an END chunk.
 
443
            return self.bytes_left + 4
 
444
        elif self.state_accept == self._state_accept_expecting_length:
 
445
            if self._in_buffer_len == 0:
 
446
                # We're expecting a chunk length.  There's at least two bytes
 
447
                # left: a digit plus '\n'.
 
448
                return 2
 
449
            else:
 
450
                # We're in the middle of reading a chunk length.  So there's at
 
451
                # least one byte left, the '\n' that terminates the length.
 
452
                return 1
 
453
        elif self.state_accept == self._state_accept_reading_unused:
 
454
            return 1
 
455
        elif self.state_accept == self._state_accept_expecting_header:
 
456
            return max(0, len('chunked\n') - self._in_buffer_len)
 
457
        else:
 
458
            raise AssertionError("Impossible state: %r" % (self.state_accept,))
 
459
 
 
460
    def read_next_chunk(self):
 
461
        try:
 
462
            return self.chunks.popleft()
 
463
        except IndexError:
 
464
            return None
 
465
 
 
466
    def _extract_line(self):
 
467
        in_buf = self._get_in_buffer()
 
468
        pos = in_buf.find('\n')
 
469
        if pos == -1:
 
470
            # We haven't read a complete line yet, so request more bytes before
 
471
            # we continue.
 
472
            raise _NeedMoreBytes(1)
 
473
        line = in_buf[:pos]
 
474
        # Trim the prefix (including '\n' delimiter) from the _in_buffer.
 
475
        self._set_in_buffer(in_buf[pos+1:])
 
476
        return line
 
477
 
 
478
    def _finished(self):
 
479
        self.unused_data = self._get_in_buffer()
 
480
        self._in_buffer_list = []
 
481
        self._in_buffer_len = 0
 
482
        self.state_accept = self._state_accept_reading_unused
 
483
        if self.error:
 
484
            error_args = tuple(self.error_in_progress)
 
485
            self.chunks.append(request.FailedSmartServerResponse(error_args))
 
486
            self.error_in_progress = None
 
487
        self.finished_reading = True
 
488
 
 
489
    def _state_accept_expecting_header(self):
 
490
        prefix = self._extract_line()
 
491
        if prefix == 'chunked':
 
492
            self.state_accept = self._state_accept_expecting_length
 
493
        else:
 
494
            raise errors.SmartProtocolError(
 
495
                'Bad chunked body header: "%s"' % (prefix,))
 
496
 
 
497
    def _state_accept_expecting_length(self):
 
498
        prefix = self._extract_line()
 
499
        if prefix == 'ERR':
 
500
            self.error = True
 
501
            self.error_in_progress = []
 
502
            self._state_accept_expecting_length()
 
503
            return
 
504
        elif prefix == 'END':
 
505
            # We've read the end-of-body marker.
 
506
            # Any further bytes are unused data, including the bytes left in
 
507
            # the _in_buffer.
 
508
            self._finished()
 
509
            return
 
510
        else:
 
511
            self.bytes_left = int(prefix, 16)
 
512
            self.chunk_in_progress = ''
 
513
            self.state_accept = self._state_accept_reading_chunk
 
514
 
 
515
    def _state_accept_reading_chunk(self):
 
516
        in_buf = self._get_in_buffer()
 
517
        in_buffer_len = len(in_buf)
 
518
        self.chunk_in_progress += in_buf[:self.bytes_left]
 
519
        self._set_in_buffer(in_buf[self.bytes_left:])
 
520
        self.bytes_left -= in_buffer_len
 
521
        if self.bytes_left <= 0:
 
522
            # Finished with chunk
 
523
            self.bytes_left = None
 
524
            if self.error:
 
525
                self.error_in_progress.append(self.chunk_in_progress)
 
526
            else:
 
527
                self.chunks.append(self.chunk_in_progress)
 
528
            self.chunk_in_progress = None
 
529
            self.state_accept = self._state_accept_expecting_length
 
530
 
 
531
    def _state_accept_reading_unused(self):
 
532
        self.unused_data += self._get_in_buffer()
 
533
        self._in_buffer_list = []
 
534
 
 
535
 
 
536
class LengthPrefixedBodyDecoder(_StatefulDecoder):
 
537
    """Decodes the length-prefixed bulk data."""
 
538
 
 
539
    def __init__(self):
 
540
        _StatefulDecoder.__init__(self)
 
541
        self.state_accept = self._state_accept_expecting_length
 
542
        self.state_read = self._state_read_no_data
 
543
        self._body = ''
 
544
        self._trailer_buffer = ''
 
545
 
 
546
    def next_read_size(self):
 
547
        if self.bytes_left is not None:
 
548
            # Ideally we want to read all the remainder of the body and the
 
549
            # trailer in one go.
 
550
            return self.bytes_left + 5
 
551
        elif self.state_accept == self._state_accept_reading_trailer:
 
552
            # Just the trailer left
 
553
            return 5 - len(self._trailer_buffer)
 
554
        elif self.state_accept == self._state_accept_expecting_length:
 
555
            # There's still at least 6 bytes left ('\n' to end the length, plus
 
556
            # 'done\n').
 
557
            return 6
 
558
        else:
 
559
            # Reading excess data.  Either way, 1 byte at a time is fine.
 
560
            return 1
 
561
 
 
562
    def read_pending_data(self):
 
563
        """Return any pending data that has been decoded."""
 
564
        return self.state_read()
 
565
 
 
566
    def _state_accept_expecting_length(self):
 
567
        in_buf = self._get_in_buffer()
 
568
        pos = in_buf.find('\n')
 
569
        if pos == -1:
 
570
            return
 
571
        self.bytes_left = int(in_buf[:pos])
 
572
        self._set_in_buffer(in_buf[pos+1:])
 
573
        self.state_accept = self._state_accept_reading_body
 
574
        self.state_read = self._state_read_body_buffer
 
575
 
 
576
    def _state_accept_reading_body(self):
 
577
        in_buf = self._get_in_buffer()
 
578
        self._body += in_buf
 
579
        self.bytes_left -= len(in_buf)
 
580
        self._set_in_buffer(None)
 
581
        if self.bytes_left <= 0:
 
582
            # Finished with body
 
583
            if self.bytes_left != 0:
 
584
                self._trailer_buffer = self._body[self.bytes_left:]
 
585
                self._body = self._body[:self.bytes_left]
 
586
            self.bytes_left = None
 
587
            self.state_accept = self._state_accept_reading_trailer
 
588
 
 
589
    def _state_accept_reading_trailer(self):
 
590
        self._trailer_buffer += self._get_in_buffer()
 
591
        self._set_in_buffer(None)
 
592
        # TODO: what if the trailer does not match "done\n"?  Should this raise
 
593
        # a ProtocolViolation exception?
 
594
        if self._trailer_buffer.startswith('done\n'):
 
595
            self.unused_data = self._trailer_buffer[len('done\n'):]
 
596
            self.state_accept = self._state_accept_reading_unused
 
597
            self.finished_reading = True
 
598
 
 
599
    def _state_accept_reading_unused(self):
 
600
        self.unused_data += self._get_in_buffer()
 
601
        self._set_in_buffer(None)
 
602
 
 
603
    def _state_read_no_data(self):
 
604
        return ''
 
605
 
 
606
    def _state_read_body_buffer(self):
 
607
        result = self._body
 
608
        self._body = ''
 
609
        return result
 
610
 
 
611
 
 
612
class SmartClientRequestProtocolOne(SmartProtocolBase, Requester,
 
613
                                    message.ResponseHandler):
 
614
    """The client-side protocol for smart version 1."""
 
615
 
 
616
    def __init__(self, request):
 
617
        """Construct a SmartClientRequestProtocolOne.
 
618
 
 
619
        :param request: A SmartClientMediumRequest to serialise onto and
 
620
            deserialise from.
 
621
        """
 
622
        self._request = request
 
623
        self._body_buffer = None
 
624
        self._request_start_time = None
 
625
        self._last_verb = None
 
626
        self._headers = None
 
627
 
 
628
    def set_headers(self, headers):
 
629
        self._headers = dict(headers)
 
630
 
 
631
    def call(self, *args):
 
632
        if 'hpss' in debug.debug_flags:
 
633
            mutter('hpss call:   %s', repr(args)[1:-1])
 
634
            if getattr(self._request._medium, 'base', None) is not None:
 
635
                mutter('             (to %s)', self._request._medium.base)
 
636
            self._request_start_time = osutils.timer_func()
 
637
        self._write_args(args)
 
638
        self._request.finished_writing()
 
639
        self._last_verb = args[0]
 
640
 
 
641
    def call_with_body_bytes(self, args, body):
 
642
        """Make a remote call of args with body bytes 'body'.
 
643
 
 
644
        After calling this, call read_response_tuple to find the result out.
 
645
        """
 
646
        if 'hpss' in debug.debug_flags:
 
647
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
 
648
            if getattr(self._request._medium, '_path', None) is not None:
 
649
                mutter('                  (to %s)', self._request._medium._path)
 
650
            mutter('              %d bytes', len(body))
 
651
            self._request_start_time = osutils.timer_func()
 
652
            if 'hpssdetail' in debug.debug_flags:
 
653
                mutter('hpss body content: %s', body)
 
654
        self._write_args(args)
 
655
        bytes = self._encode_bulk_data(body)
 
656
        self._request.accept_bytes(bytes)
 
657
        self._request.finished_writing()
 
658
        self._last_verb = args[0]
 
659
 
 
660
    def call_with_body_readv_array(self, args, body):
 
661
        """Make a remote call with a readv array.
 
662
 
 
663
        The body is encoded with one line per readv offset pair. The numbers in
 
664
        each pair are separated by a comma, and no trailing \\n is emitted.
 
665
        """
 
666
        if 'hpss' in debug.debug_flags:
 
667
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
 
668
            if getattr(self._request._medium, '_path', None) is not None:
 
669
                mutter('                  (to %s)', self._request._medium._path)
 
670
            self._request_start_time = osutils.timer_func()
 
671
        self._write_args(args)
 
672
        readv_bytes = self._serialise_offsets(body)
 
673
        bytes = self._encode_bulk_data(readv_bytes)
 
674
        self._request.accept_bytes(bytes)
 
675
        self._request.finished_writing()
 
676
        if 'hpss' in debug.debug_flags:
 
677
            mutter('              %d bytes in readv request', len(readv_bytes))
 
678
        self._last_verb = args[0]
 
679
 
 
680
    def call_with_body_stream(self, args, stream):
 
681
        # Protocols v1 and v2 don't support body streams.  So it's safe to
 
682
        # assume that a v1/v2 server doesn't support whatever method we're
 
683
        # trying to call with a body stream.
 
684
        self._request.finished_writing()
 
685
        self._request.finished_reading()
 
686
        raise errors.UnknownSmartMethod(args[0])
 
687
 
 
688
    def cancel_read_body(self):
 
689
        """After expecting a body, a response code may indicate one otherwise.
 
690
 
 
691
        This method lets the domain client inform the protocol that no body
 
692
        will be transmitted. This is a terminal method: after calling it the
 
693
        protocol is not able to be used further.
 
694
        """
 
695
        self._request.finished_reading()
 
696
 
 
697
    def _read_response_tuple(self):
 
698
        result = self._recv_tuple()
 
699
        if 'hpss' in debug.debug_flags:
 
700
            if self._request_start_time is not None:
 
701
                mutter('   result:   %6.3fs  %s',
 
702
                       osutils.timer_func() - self._request_start_time,
 
703
                       repr(result)[1:-1])
 
704
                self._request_start_time = None
 
705
            else:
 
706
                mutter('   result:   %s', repr(result)[1:-1])
 
707
        return result
 
708
 
 
709
    def read_response_tuple(self, expect_body=False):
 
710
        """Read a response tuple from the wire.
 
711
 
 
712
        This should only be called once.
 
713
        """
 
714
        result = self._read_response_tuple()
 
715
        self._response_is_unknown_method(result)
 
716
        self._raise_args_if_error(result)
 
717
        if not expect_body:
 
718
            self._request.finished_reading()
 
719
        return result
 
720
 
 
721
    def _raise_args_if_error(self, result_tuple):
 
722
        # Later protocol versions have an explicit flag in the protocol to say
 
723
        # if an error response is "failed" or not.  In version 1 we don't have
 
724
        # that luxury.  So here is a complete list of errors that can be
 
725
        # returned in response to existing version 1 smart requests.  Responses
 
726
        # starting with these codes are always "failed" responses.
 
727
        v1_error_codes = [
 
728
            'norepository',
 
729
            'NoSuchFile',
 
730
            'FileExists',
 
731
            'DirectoryNotEmpty',
 
732
            'ShortReadvError',
 
733
            'UnicodeEncodeError',
 
734
            'UnicodeDecodeError',
 
735
            'ReadOnlyError',
 
736
            'nobranch',
 
737
            'NoSuchRevision',
 
738
            'nosuchrevision',
 
739
            'LockContention',
 
740
            'UnlockableTransport',
 
741
            'LockFailed',
 
742
            'TokenMismatch',
 
743
            'ReadError',
 
744
            'PermissionDenied',
 
745
            ]
 
746
        if result_tuple[0] in v1_error_codes:
 
747
            self._request.finished_reading()
 
748
            raise errors.ErrorFromSmartServer(result_tuple)
 
749
 
 
750
    def _response_is_unknown_method(self, result_tuple):
 
751
        """Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
 
752
        method' response to the request.
 
753
 
 
754
        :param response: The response from a smart client call_expecting_body
 
755
            call.
 
756
        :param verb: The verb used in that call.
 
757
        :raises: UnexpectedSmartServerResponse
 
758
        """
 
759
        if (result_tuple == ('error', "Generic bzr smart protocol error: "
 
760
                "bad request '%s'" % self._last_verb) or
 
761
              result_tuple == ('error', "Generic bzr smart protocol error: "
 
762
                "bad request u'%s'" % self._last_verb)):
 
763
            # The response will have no body, so we've finished reading.
 
764
            self._request.finished_reading()
 
765
            raise errors.UnknownSmartMethod(self._last_verb)
 
766
 
 
767
    def read_body_bytes(self, count=-1):
 
768
        """Read bytes from the body, decoding into a byte stream.
 
769
 
 
770
        We read all bytes at once to ensure we've checked the trailer for
 
771
        errors, and then feed the buffer back as read_body_bytes is called.
 
772
        """
 
773
        if self._body_buffer is not None:
 
774
            return self._body_buffer.read(count)
 
775
        _body_decoder = LengthPrefixedBodyDecoder()
 
776
 
 
777
        while not _body_decoder.finished_reading:
 
778
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
779
            if bytes == '':
 
780
                # end of file encountered reading from server
 
781
                raise errors.ConnectionReset(
 
782
                    "Connection lost while reading response body.")
 
783
            _body_decoder.accept_bytes(bytes)
 
784
        self._request.finished_reading()
 
785
        self._body_buffer = BytesIO(_body_decoder.read_pending_data())
 
786
        # XXX: TODO check the trailer result.
 
787
        if 'hpss' in debug.debug_flags:
 
788
            mutter('              %d body bytes read',
 
789
                   len(self._body_buffer.getvalue()))
 
790
        return self._body_buffer.read(count)
 
791
 
 
792
    def _recv_tuple(self):
 
793
        """Receive a tuple from the medium request."""
 
794
        return _decode_tuple(self._request.read_line())
 
795
 
 
796
    def query_version(self):
 
797
        """Return protocol version number of the server."""
 
798
        self.call('hello')
 
799
        resp = self.read_response_tuple()
 
800
        if resp == ('ok', '1'):
 
801
            return 1
 
802
        elif resp == ('ok', '2'):
 
803
            return 2
 
804
        else:
 
805
            raise errors.SmartProtocolError("bad response %r" % (resp,))
 
806
 
 
807
    def _write_args(self, args):
 
808
        self._write_protocol_version()
 
809
        bytes = _encode_tuple(args)
 
810
        self._request.accept_bytes(bytes)
 
811
 
 
812
    def _write_protocol_version(self):
 
813
        """Write any prefixes this protocol requires.
 
814
 
 
815
        Version one doesn't send protocol versions.
 
816
        """
 
817
 
 
818
 
 
819
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
 
820
    """Version two of the client side of the smart protocol.
 
821
 
 
822
    This prefixes the request with the value of REQUEST_VERSION_TWO.
 
823
    """
 
824
 
 
825
    response_marker = RESPONSE_VERSION_TWO
 
826
    request_marker = REQUEST_VERSION_TWO
 
827
 
 
828
    def read_response_tuple(self, expect_body=False):
 
829
        """Read a response tuple from the wire.
 
830
 
 
831
        This should only be called once.
 
832
        """
 
833
        version = self._request.read_line()
 
834
        if version != self.response_marker:
 
835
            self._request.finished_reading()
 
836
            raise errors.UnexpectedProtocolVersionMarker(version)
 
837
        response_status = self._request.read_line()
 
838
        result = SmartClientRequestProtocolOne._read_response_tuple(self)
 
839
        self._response_is_unknown_method(result)
 
840
        if response_status == 'success\n':
 
841
            self.response_status = True
 
842
            if not expect_body:
 
843
                self._request.finished_reading()
 
844
            return result
 
845
        elif response_status == 'failed\n':
 
846
            self.response_status = False
 
847
            self._request.finished_reading()
 
848
            raise errors.ErrorFromSmartServer(result)
 
849
        else:
 
850
            raise errors.SmartProtocolError(
 
851
                'bad protocol status %r' % response_status)
 
852
 
 
853
    def _write_protocol_version(self):
 
854
        """Write any prefixes this protocol requires.
 
855
 
 
856
        Version two sends the value of REQUEST_VERSION_TWO.
 
857
        """
 
858
        self._request.accept_bytes(self.request_marker)
 
859
 
 
860
    def read_streamed_body(self):
 
861
        """Read bytes from the body, decoding into a byte stream.
 
862
        """
 
863
        # Read no more than 64k at a time so that we don't risk error 10055 (no
 
864
        # buffer space available) on Windows.
 
865
        _body_decoder = ChunkedBodyDecoder()
 
866
        while not _body_decoder.finished_reading:
 
867
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
868
            if bytes == '':
 
869
                # end of file encountered reading from server
 
870
                raise errors.ConnectionReset(
 
871
                    "Connection lost while reading streamed body.")
 
872
            _body_decoder.accept_bytes(bytes)
 
873
            for body_bytes in iter(_body_decoder.read_next_chunk, None):
 
874
                if 'hpss' in debug.debug_flags and isinstance(body_bytes, str):
 
875
                    mutter('              %d byte chunk read',
 
876
                           len(body_bytes))
 
877
                yield body_bytes
 
878
        self._request.finished_reading()
 
879
 
 
880
 
 
881
def build_server_protocol_three(backing_transport, write_func,
 
882
                                root_client_path, jail_root=None):
 
883
    request_handler = request.SmartServerRequestHandler(
 
884
        backing_transport, commands=request.request_handlers,
 
885
        root_client_path=root_client_path, jail_root=jail_root)
 
886
    responder = ProtocolThreeResponder(write_func)
 
887
    message_handler = message.ConventionalRequestHandler(request_handler, responder)
 
888
    return ProtocolThreeDecoder(message_handler)
 
889
 
 
890
 
 
891
class ProtocolThreeDecoder(_StatefulDecoder):
 
892
 
 
893
    response_marker = RESPONSE_VERSION_THREE
 
894
    request_marker = REQUEST_VERSION_THREE
 
895
 
 
896
    def __init__(self, message_handler, expect_version_marker=False):
 
897
        _StatefulDecoder.__init__(self)
 
898
        self._has_dispatched = False
 
899
        # Initial state
 
900
        if expect_version_marker:
 
901
            self.state_accept = self._state_accept_expecting_protocol_version
 
902
            # We're expecting at least the protocol version marker + some
 
903
            # headers.
 
904
            self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
 
905
        else:
 
906
            self.state_accept = self._state_accept_expecting_headers
 
907
            self._number_needed_bytes = 4
 
908
        self.decoding_failed = False
 
909
        self.request_handler = self.message_handler = message_handler
 
910
 
 
911
    def accept_bytes(self, bytes):
 
912
        self._number_needed_bytes = None
 
913
        try:
 
914
            _StatefulDecoder.accept_bytes(self, bytes)
 
915
        except KeyboardInterrupt:
 
916
            raise
 
917
        except errors.SmartMessageHandlerError as exception:
 
918
            # We do *not* set self.decoding_failed here.  The message handler
 
919
            # has raised an error, but the decoder is still able to parse bytes
 
920
            # and determine when this message ends.
 
921
            if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
 
922
                log_exception_quietly()
 
923
            self.message_handler.protocol_error(exception.exc_value)
 
924
            # The state machine is ready to continue decoding, but the
 
925
            # exception has interrupted the loop that runs the state machine.
 
926
            # So we call accept_bytes again to restart it.
 
927
            self.accept_bytes('')
 
928
        except Exception as exception:
 
929
            # The decoder itself has raised an exception.  We cannot continue
 
930
            # decoding.
 
931
            self.decoding_failed = True
 
932
            if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
 
933
                # This happens during normal operation when the client tries a
 
934
                # protocol version the server doesn't understand, so no need to
 
935
                # log a traceback every time.
 
936
                # Note that this can only happen when
 
937
                # expect_version_marker=True, which is only the case on the
 
938
                # client side.
 
939
                pass
 
940
            else:
 
941
                log_exception_quietly()
 
942
            self.message_handler.protocol_error(exception)
 
943
 
 
944
    def _extract_length_prefixed_bytes(self):
 
945
        if self._in_buffer_len < 4:
 
946
            # A length prefix by itself is 4 bytes, and we don't even have that
 
947
            # many yet.
 
948
            raise _NeedMoreBytes(4)
 
949
        (length,) = struct.unpack('!L', self._get_in_bytes(4))
 
950
        end_of_bytes = 4 + length
 
951
        if self._in_buffer_len < end_of_bytes:
 
952
            # We haven't yet read as many bytes as the length-prefix says there
 
953
            # are.
 
954
            raise _NeedMoreBytes(end_of_bytes)
 
955
        # Extract the bytes from the buffer.
 
956
        in_buf = self._get_in_buffer()
 
957
        bytes = in_buf[4:end_of_bytes]
 
958
        self._set_in_buffer(in_buf[end_of_bytes:])
 
959
        return bytes
 
960
 
 
961
    def _extract_prefixed_bencoded_data(self):
 
962
        prefixed_bytes = self._extract_length_prefixed_bytes()
 
963
        try:
 
964
            decoded = bdecode_as_tuple(prefixed_bytes)
 
965
        except ValueError:
 
966
            raise errors.SmartProtocolError(
 
967
                'Bytes %r not bencoded' % (prefixed_bytes,))
 
968
        return decoded
 
969
 
 
970
    def _extract_single_byte(self):
 
971
        if self._in_buffer_len == 0:
 
972
            # The buffer is empty
 
973
            raise _NeedMoreBytes(1)
 
974
        in_buf = self._get_in_buffer()
 
975
        one_byte = in_buf[0]
 
976
        self._set_in_buffer(in_buf[1:])
 
977
        return one_byte
 
978
 
 
979
    def _state_accept_expecting_protocol_version(self):
 
980
        needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
 
981
        in_buf = self._get_in_buffer()
 
982
        if needed_bytes > 0:
 
983
            # We don't have enough bytes to check if the protocol version
 
984
            # marker is right.  But we can check if it is already wrong by
 
985
            # checking that the start of MESSAGE_VERSION_THREE matches what
 
986
            # we've read so far.
 
987
            # [In fact, if the remote end isn't bzr we might never receive
 
988
            # len(MESSAGE_VERSION_THREE) bytes.  So if the bytes we have so far
 
989
            # are wrong then we should just raise immediately rather than
 
990
            # stall.]
 
991
            if not MESSAGE_VERSION_THREE.startswith(in_buf):
 
992
                # We have enough bytes to know the protocol version is wrong
 
993
                raise errors.UnexpectedProtocolVersionMarker(in_buf)
 
994
            raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
 
995
        if not in_buf.startswith(MESSAGE_VERSION_THREE):
 
996
            raise errors.UnexpectedProtocolVersionMarker(in_buf)
 
997
        self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
 
998
        self.state_accept = self._state_accept_expecting_headers
 
999
 
 
1000
    def _state_accept_expecting_headers(self):
 
1001
        decoded = self._extract_prefixed_bencoded_data()
 
1002
        if not isinstance(decoded, dict):
 
1003
            raise errors.SmartProtocolError(
 
1004
                'Header object %r is not a dict' % (decoded,))
 
1005
        self.state_accept = self._state_accept_expecting_message_part
 
1006
        try:
 
1007
            self.message_handler.headers_received(decoded)
 
1008
        except:
 
1009
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1010
 
 
1011
    def _state_accept_expecting_message_part(self):
 
1012
        message_part_kind = self._extract_single_byte()
 
1013
        if message_part_kind == 'o':
 
1014
            self.state_accept = self._state_accept_expecting_one_byte
 
1015
        elif message_part_kind == 's':
 
1016
            self.state_accept = self._state_accept_expecting_structure
 
1017
        elif message_part_kind == 'b':
 
1018
            self.state_accept = self._state_accept_expecting_bytes
 
1019
        elif message_part_kind == 'e':
 
1020
            self.done()
 
1021
        else:
 
1022
            raise errors.SmartProtocolError(
 
1023
                'Bad message kind byte: %r' % (message_part_kind,))
 
1024
 
 
1025
    def _state_accept_expecting_one_byte(self):
 
1026
        byte = self._extract_single_byte()
 
1027
        self.state_accept = self._state_accept_expecting_message_part
 
1028
        try:
 
1029
            self.message_handler.byte_part_received(byte)
 
1030
        except:
 
1031
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1032
 
 
1033
    def _state_accept_expecting_bytes(self):
 
1034
        # XXX: this should not buffer whole message part, but instead deliver
 
1035
        # the bytes as they arrive.
 
1036
        prefixed_bytes = self._extract_length_prefixed_bytes()
 
1037
        self.state_accept = self._state_accept_expecting_message_part
 
1038
        try:
 
1039
            self.message_handler.bytes_part_received(prefixed_bytes)
 
1040
        except:
 
1041
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1042
 
 
1043
    def _state_accept_expecting_structure(self):
 
1044
        structure = self._extract_prefixed_bencoded_data()
 
1045
        self.state_accept = self._state_accept_expecting_message_part
 
1046
        try:
 
1047
            self.message_handler.structure_part_received(structure)
 
1048
        except:
 
1049
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1050
 
 
1051
    def done(self):
 
1052
        self.unused_data = self._get_in_buffer()
 
1053
        self._set_in_buffer(None)
 
1054
        self.state_accept = self._state_accept_reading_unused
 
1055
        try:
 
1056
            self.message_handler.end_received()
 
1057
        except:
 
1058
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1059
 
 
1060
    def _state_accept_reading_unused(self):
 
1061
        self.unused_data += self._get_in_buffer()
 
1062
        self._set_in_buffer(None)
 
1063
 
 
1064
    def next_read_size(self):
 
1065
        if self.state_accept == self._state_accept_reading_unused:
 
1066
            return 0
 
1067
        elif self.decoding_failed:
 
1068
            # An exception occured while processing this message, probably from
 
1069
            # self.message_handler.  We're not sure that this state machine is
 
1070
            # in a consistent state, so just signal that we're done (i.e. give
 
1071
            # up).
 
1072
            return 0
 
1073
        else:
 
1074
            if self._number_needed_bytes is not None:
 
1075
                return self._number_needed_bytes - self._in_buffer_len
 
1076
            else:
 
1077
                raise AssertionError("don't know how many bytes are expected!")
 
1078
 
 
1079
 
 
1080
class _ProtocolThreeEncoder(object):
 
1081
 
 
1082
    response_marker = request_marker = MESSAGE_VERSION_THREE
 
1083
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
 
1084
 
 
1085
    def __init__(self, write_func):
 
1086
        self._buf = []
 
1087
        self._buf_len = 0
 
1088
        self._real_write_func = write_func
 
1089
 
 
1090
    def _write_func(self, bytes):
 
1091
        # TODO: Another possibility would be to turn this into an async model.
 
1092
        #       Where we let another thread know that we have some bytes if
 
1093
        #       they want it, but we don't actually block for it
 
1094
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
 
1095
        #       we might just push out smaller bits at a time?
 
1096
        self._buf.append(bytes)
 
1097
        self._buf_len += len(bytes)
 
1098
        if self._buf_len > self.BUFFER_SIZE:
 
1099
            self.flush()
 
1100
 
 
1101
    def flush(self):
 
1102
        if self._buf:
 
1103
            self._real_write_func(''.join(self._buf))
 
1104
            del self._buf[:]
 
1105
            self._buf_len = 0
 
1106
 
 
1107
    def _serialise_offsets(self, offsets):
 
1108
        """Serialise a readv offset list."""
 
1109
        txt = []
 
1110
        for start, length in offsets:
 
1111
            txt.append('%d,%d' % (start, length))
 
1112
        return '\n'.join(txt)
 
1113
 
 
1114
    def _write_protocol_version(self):
 
1115
        self._write_func(MESSAGE_VERSION_THREE)
 
1116
 
 
1117
    def _write_prefixed_bencode(self, structure):
 
1118
        bytes = bencode(structure)
 
1119
        self._write_func(struct.pack('!L', len(bytes)))
 
1120
        self._write_func(bytes)
 
1121
 
 
1122
    def _write_headers(self, headers):
 
1123
        self._write_prefixed_bencode(headers)
 
1124
 
 
1125
    def _write_structure(self, args):
 
1126
        self._write_func('s')
 
1127
        utf8_args = []
 
1128
        for arg in args:
 
1129
            if isinstance(arg, unicode):
 
1130
                utf8_args.append(arg.encode('utf8'))
 
1131
            else:
 
1132
                utf8_args.append(arg)
 
1133
        self._write_prefixed_bencode(utf8_args)
 
1134
 
 
1135
    def _write_end(self):
 
1136
        self._write_func('e')
 
1137
        self.flush()
 
1138
 
 
1139
    def _write_prefixed_body(self, bytes):
 
1140
        self._write_func('b')
 
1141
        self._write_func(struct.pack('!L', len(bytes)))
 
1142
        self._write_func(bytes)
 
1143
 
 
1144
    def _write_chunked_body_start(self):
 
1145
        self._write_func('oC')
 
1146
 
 
1147
    def _write_error_status(self):
 
1148
        self._write_func('oE')
 
1149
 
 
1150
    def _write_success_status(self):
 
1151
        self._write_func('oS')
 
1152
 
 
1153
 
 
1154
class ProtocolThreeResponder(_ProtocolThreeEncoder):
 
1155
 
 
1156
    def __init__(self, write_func):
 
1157
        _ProtocolThreeEncoder.__init__(self, write_func)
 
1158
        self.response_sent = False
 
1159
        self._headers = {'Software version': breezy.__version__}
 
1160
        if 'hpss' in debug.debug_flags:
 
1161
            self._thread_id = _thread.get_ident()
 
1162
            self._response_start_time = None
 
1163
 
 
1164
    def _trace(self, action, message, extra_bytes=None, include_time=False):
 
1165
        if self._response_start_time is None:
 
1166
            self._response_start_time = osutils.timer_func()
 
1167
        if include_time:
 
1168
            t = '%5.3fs ' % (time.clock() - self._response_start_time)
 
1169
        else:
 
1170
            t = ''
 
1171
        if extra_bytes is None:
 
1172
            extra = ''
 
1173
        else:
 
1174
            extra = ' ' + repr(extra_bytes[:40])
 
1175
            if len(extra) > 33:
 
1176
                extra = extra[:29] + extra[-1] + '...'
 
1177
        mutter('%12s: [%s] %s%s%s'
 
1178
               % (action, self._thread_id, t, message, extra))
 
1179
 
 
1180
    def send_error(self, exception):
 
1181
        if self.response_sent:
 
1182
            raise AssertionError(
 
1183
                "send_error(%s) called, but response already sent."
 
1184
                % (exception,))
 
1185
        if isinstance(exception, errors.UnknownSmartMethod):
 
1186
            failure = request.FailedSmartServerResponse(
 
1187
                ('UnknownMethod', exception.verb))
 
1188
            self.send_response(failure)
 
1189
            return
 
1190
        if 'hpss' in debug.debug_flags:
 
1191
            self._trace('error', str(exception))
 
1192
        self.response_sent = True
 
1193
        self._write_protocol_version()
 
1194
        self._write_headers(self._headers)
 
1195
        self._write_error_status()
 
1196
        self._write_structure(('error', str(exception)))
 
1197
        self._write_end()
 
1198
 
 
1199
    def send_response(self, response):
 
1200
        if self.response_sent:
 
1201
            raise AssertionError(
 
1202
                "send_response(%r) called, but response already sent."
 
1203
                % (response,))
 
1204
        self.response_sent = True
 
1205
        self._write_protocol_version()
 
1206
        self._write_headers(self._headers)
 
1207
        if response.is_successful():
 
1208
            self._write_success_status()
 
1209
        else:
 
1210
            self._write_error_status()
 
1211
        if 'hpss' in debug.debug_flags:
 
1212
            self._trace('response', repr(response.args))
 
1213
        self._write_structure(response.args)
 
1214
        if response.body is not None:
 
1215
            self._write_prefixed_body(response.body)
 
1216
            if 'hpss' in debug.debug_flags:
 
1217
                self._trace('body', '%d bytes' % (len(response.body),),
 
1218
                            response.body, include_time=True)
 
1219
        elif response.body_stream is not None:
 
1220
            count = num_bytes = 0
 
1221
            first_chunk = None
 
1222
            for exc_info, chunk in _iter_with_errors(response.body_stream):
 
1223
                count += 1
 
1224
                if exc_info is not None:
 
1225
                    self._write_error_status()
 
1226
                    error_struct = request._translate_error(exc_info[1])
 
1227
                    self._write_structure(error_struct)
 
1228
                    break
 
1229
                else:
 
1230
                    if isinstance(chunk, request.FailedSmartServerResponse):
 
1231
                        self._write_error_status()
 
1232
                        self._write_structure(chunk.args)
 
1233
                        break
 
1234
                    num_bytes += len(chunk)
 
1235
                    if first_chunk is None:
 
1236
                        first_chunk = chunk
 
1237
                    self._write_prefixed_body(chunk)
 
1238
                    self.flush()
 
1239
                    if 'hpssdetail' in debug.debug_flags:
 
1240
                        # Not worth timing separately, as _write_func is
 
1241
                        # actually buffered
 
1242
                        self._trace('body chunk',
 
1243
                                    '%d bytes' % (len(chunk),),
 
1244
                                    chunk, suppress_time=True)
 
1245
            if 'hpss' in debug.debug_flags:
 
1246
                self._trace('body stream',
 
1247
                            '%d bytes %d chunks' % (num_bytes, count),
 
1248
                            first_chunk)
 
1249
        self._write_end()
 
1250
        if 'hpss' in debug.debug_flags:
 
1251
            self._trace('response end', '', include_time=True)
 
1252
 
 
1253
 
 
1254
def _iter_with_errors(iterable):
 
1255
    """Handle errors from iterable.next().
 
1256
 
 
1257
    Use like::
 
1258
 
 
1259
        for exc_info, value in _iter_with_errors(iterable):
 
1260
            ...
 
1261
 
 
1262
    This is a safer alternative to::
 
1263
 
 
1264
        try:
 
1265
            for value in iterable:
 
1266
               ...
 
1267
        except:
 
1268
            ...
 
1269
 
 
1270
    Because the latter will catch errors from the for-loop body, not just
 
1271
    iterable.next()
 
1272
 
 
1273
    If an error occurs, exc_info will be a exc_info tuple, and the generator
 
1274
    will terminate.  Otherwise exc_info will be None, and value will be the
 
1275
    value from iterable.next().  Note that KeyboardInterrupt and SystemExit
 
1276
    will not be itercepted.
 
1277
    """
 
1278
    iterator = iter(iterable)
 
1279
    while True:
 
1280
        try:
 
1281
            yield None, next(iterator)
 
1282
        except StopIteration:
 
1283
            return
 
1284
        except (KeyboardInterrupt, SystemExit):
 
1285
            raise
 
1286
        except Exception:
 
1287
            mutter('_iter_with_errors caught error')
 
1288
            log_exception_quietly()
 
1289
            yield sys.exc_info(), None
 
1290
            return
 
1291
 
 
1292
 
 
1293
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
 
1294
 
 
1295
    def __init__(self, medium_request):
 
1296
        _ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
 
1297
        self._medium_request = medium_request
 
1298
        self._headers = {}
 
1299
        self.body_stream_started = None
 
1300
 
 
1301
    def set_headers(self, headers):
 
1302
        self._headers = headers.copy()
 
1303
 
 
1304
    def call(self, *args):
 
1305
        if 'hpss' in debug.debug_flags:
 
1306
            mutter('hpss call:   %s', repr(args)[1:-1])
 
1307
            base = getattr(self._medium_request._medium, 'base', None)
 
1308
            if base is not None:
 
1309
                mutter('             (to %s)', base)
 
1310
            self._request_start_time = osutils.timer_func()
 
1311
        self._write_protocol_version()
 
1312
        self._write_headers(self._headers)
 
1313
        self._write_structure(args)
 
1314
        self._write_end()
 
1315
        self._medium_request.finished_writing()
 
1316
 
 
1317
    def call_with_body_bytes(self, args, body):
 
1318
        """Make a remote call of args with body bytes 'body'.
 
1319
 
 
1320
        After calling this, call read_response_tuple to find the result out.
 
1321
        """
 
1322
        if 'hpss' in debug.debug_flags:
 
1323
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
 
1324
            path = getattr(self._medium_request._medium, '_path', None)
 
1325
            if path is not None:
 
1326
                mutter('                  (to %s)', path)
 
1327
            mutter('              %d bytes', len(body))
 
1328
            self._request_start_time = osutils.timer_func()
 
1329
        self._write_protocol_version()
 
1330
        self._write_headers(self._headers)
 
1331
        self._write_structure(args)
 
1332
        self._write_prefixed_body(body)
 
1333
        self._write_end()
 
1334
        self._medium_request.finished_writing()
 
1335
 
 
1336
    def call_with_body_readv_array(self, args, body):
 
1337
        """Make a remote call with a readv array.
 
1338
 
 
1339
        The body is encoded with one line per readv offset pair. The numbers in
 
1340
        each pair are separated by a comma, and no trailing \\n is emitted.
 
1341
        """
 
1342
        if 'hpss' in debug.debug_flags:
 
1343
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
 
1344
            path = getattr(self._medium_request._medium, '_path', None)
 
1345
            if path is not None:
 
1346
                mutter('                  (to %s)', path)
 
1347
            self._request_start_time = osutils.timer_func()
 
1348
        self._write_protocol_version()
 
1349
        self._write_headers(self._headers)
 
1350
        self._write_structure(args)
 
1351
        readv_bytes = self._serialise_offsets(body)
 
1352
        if 'hpss' in debug.debug_flags:
 
1353
            mutter('              %d bytes in readv request', len(readv_bytes))
 
1354
        self._write_prefixed_body(readv_bytes)
 
1355
        self._write_end()
 
1356
        self._medium_request.finished_writing()
 
1357
 
 
1358
    def call_with_body_stream(self, args, stream):
 
1359
        if 'hpss' in debug.debug_flags:
 
1360
            mutter('hpss call w/body stream: %r', args)
 
1361
            path = getattr(self._medium_request._medium, '_path', None)
 
1362
            if path is not None:
 
1363
                mutter('                  (to %s)', path)
 
1364
            self._request_start_time = osutils.timer_func()
 
1365
        self.body_stream_started = False
 
1366
        self._write_protocol_version()
 
1367
        self._write_headers(self._headers)
 
1368
        self._write_structure(args)
 
1369
        # TODO: notice if the server has sent an early error reply before we
 
1370
        #       have finished sending the stream.  We would notice at the end
 
1371
        #       anyway, but if the medium can deliver it early then it's good
 
1372
        #       to short-circuit the whole request...
 
1373
        # Provoke any ConnectionReset failures before we start the body stream.
 
1374
        self.flush()
 
1375
        self.body_stream_started = True
 
1376
        for exc_info, part in _iter_with_errors(stream):
 
1377
            if exc_info is not None:
 
1378
                # Iterating the stream failed.  Cleanly abort the request.
 
1379
                self._write_error_status()
 
1380
                # Currently the client unconditionally sends ('error',) as the
 
1381
                # error args.
 
1382
                self._write_structure(('error',))
 
1383
                self._write_end()
 
1384
                self._medium_request.finished_writing()
 
1385
                try:
 
1386
                    reraise(*exc_info)
 
1387
                finally:
 
1388
                    del exc_info
 
1389
            else:
 
1390
                self._write_prefixed_body(part)
 
1391
                self.flush()
 
1392
        self._write_end()
 
1393
        self._medium_request.finished_writing()
 
1394