/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to breezy/bzr/smart/protocol.py

  • Committer: Jelmer Vernooij
  • Date: 2018-05-19 13:16:11 UTC
  • mto: (6968.4.3 git-archive)
  • mto: This revision was merged to the branch mainline in revision 6972.
  • Revision ID: jelmer@jelmer.uk-20180519131611-l9h9ud41j7qg1m03
Move tar/zip to breezy.archive.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006-2010 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
"""Wire-level encoding and decoding of requests and responses for the smart
 
18
client and server.
 
19
"""
 
20
 
 
21
from __future__ import absolute_import
 
22
 
 
23
import collections
 
24
import struct
 
25
import sys
 
26
try:
 
27
    import _thread
 
28
except ImportError:
 
29
    import thread as _thread
 
30
import time
 
31
 
 
32
import breezy
 
33
from ... import (
 
34
    debug,
 
35
    errors,
 
36
    osutils,
 
37
    )
 
38
from ...sixish import (
 
39
    BytesIO,
 
40
    reraise,
 
41
)
 
42
from . import message, request
 
43
from ...sixish import text_type
 
44
from ...trace import log_exception_quietly, mutter
 
45
from ...bencode import bdecode_as_tuple, bencode
 
46
 
 
47
 
 
48
# Protocol version strings.  These are sent as prefixes of bzr requests and
 
49
# responses to identify the protocol version being used. (There are no version
 
50
# one strings because that version doesn't send any).
 
51
REQUEST_VERSION_TWO = b'bzr request 2\n'
 
52
RESPONSE_VERSION_TWO = b'bzr response 2\n'
 
53
 
 
54
MESSAGE_VERSION_THREE = b'bzr message 3 (bzr 1.6)\n'
 
55
RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
 
56
 
 
57
 
 
58
def _recv_tuple(from_file):
 
59
    req_line = from_file.readline()
 
60
    return _decode_tuple(req_line)
 
61
 
 
62
 
 
63
def _decode_tuple(req_line):
 
64
    if req_line is None or req_line == b'':
 
65
        return None
 
66
    if not req_line.endswith(b'\n'):
 
67
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
 
68
    return tuple(req_line[:-1].split(b'\x01'))
 
69
 
 
70
 
 
71
def _encode_tuple(args):
 
72
    """Encode the tuple args to a bytestream."""
 
73
    for arg in args:
 
74
        if isinstance(arg, text_type):
 
75
            raise TypeError(args)
 
76
    return b'\x01'.join(args) + b'\n'
 
77
 
 
78
 
 
79
class Requester(object):
 
80
    """Abstract base class for an object that can issue requests on a smart
 
81
    medium.
 
82
    """
 
83
 
 
84
    def call(self, *args):
 
85
        """Make a remote call.
 
86
 
 
87
        :param args: the arguments of this call.
 
88
        """
 
89
        raise NotImplementedError(self.call)
 
90
 
 
91
    def call_with_body_bytes(self, args, body):
 
92
        """Make a remote call with a body.
 
93
 
 
94
        :param args: the arguments of this call.
 
95
        :type body: str
 
96
        :param body: the body to send with the request.
 
97
        """
 
98
        raise NotImplementedError(self.call_with_body_bytes)
 
99
 
 
100
    def call_with_body_readv_array(self, args, body):
 
101
        """Make a remote call with a readv array.
 
102
 
 
103
        :param args: the arguments of this call.
 
104
        :type body: iterable of (start, length) tuples.
 
105
        :param body: the readv ranges to send with this request.
 
106
        """
 
107
        raise NotImplementedError(self.call_with_body_readv_array)
 
108
 
 
109
    def set_headers(self, headers):
 
110
        raise NotImplementedError(self.set_headers)
 
111
 
 
112
 
 
113
class SmartProtocolBase(object):
 
114
    """Methods common to client and server"""
 
115
 
 
116
    # TODO: this only actually accomodates a single block; possibly should
 
117
    # support multiple chunks?
 
118
    def _encode_bulk_data(self, body):
 
119
        """Encode body as a bulk data chunk."""
 
120
        return b''.join((b'%d\n' % len(body), body, b'done\n'))
 
121
 
 
122
    def _serialise_offsets(self, offsets):
 
123
        """Serialise a readv offset list."""
 
124
        txt = []
 
125
        for start, length in offsets:
 
126
            txt.append(b'%d,%d' % (start, length))
 
127
        return b'\n'.join(txt)
 
128
 
 
129
 
 
130
class SmartServerRequestProtocolOne(SmartProtocolBase):
 
131
    """Server-side encoding and decoding logic for smart version 1."""
 
132
 
 
133
    def __init__(self, backing_transport, write_func, root_client_path='/',
 
134
            jail_root=None):
 
135
        self._backing_transport = backing_transport
 
136
        self._root_client_path = root_client_path
 
137
        self._jail_root = jail_root
 
138
        self.unused_data = b''
 
139
        self._finished = False
 
140
        self.in_buffer = b''
 
141
        self._has_dispatched = False
 
142
        self.request = None
 
143
        self._body_decoder = None
 
144
        self._write_func = write_func
 
145
 
 
146
    def accept_bytes(self, data):
 
147
        """Take bytes, and advance the internal state machine appropriately.
 
148
 
 
149
        :param data: must be a byte string
 
150
        """
 
151
        if not isinstance(data, bytes):
 
152
            raise ValueError(data)
 
153
        self.in_buffer += data
 
154
        if not self._has_dispatched:
 
155
            if b'\n' not in self.in_buffer:
 
156
                # no command line yet
 
157
                return
 
158
            self._has_dispatched = True
 
159
            try:
 
160
                first_line, self.in_buffer = self.in_buffer.split(b'\n', 1)
 
161
                first_line += b'\n'
 
162
                req_args = _decode_tuple(first_line)
 
163
                self.request = request.SmartServerRequestHandler(
 
164
                    self._backing_transport, commands=request.request_handlers,
 
165
                    root_client_path=self._root_client_path,
 
166
                    jail_root=self._jail_root)
 
167
                self.request.args_received(req_args)
 
168
                if self.request.finished_reading:
 
169
                    # trivial request
 
170
                    self.unused_data = self.in_buffer
 
171
                    self.in_buffer = b''
 
172
                    self._send_response(self.request.response)
 
173
            except KeyboardInterrupt:
 
174
                raise
 
175
            except errors.UnknownSmartMethod as err:
 
176
                protocol_error = errors.SmartProtocolError(
 
177
                    "bad request %r" % (err.verb,))
 
178
                failure = request.FailedSmartServerResponse(
 
179
                    (b'error', str(protocol_error).encode('utf-8')))
 
180
                self._send_response(failure)
 
181
                return
 
182
            except Exception as exception:
 
183
                # everything else: pass to client, flush, and quit
 
184
                log_exception_quietly()
 
185
                self._send_response(request.FailedSmartServerResponse(
 
186
                    (b'error', str(exception).encode('utf-8'))))
 
187
                return
 
188
 
 
189
        if self._has_dispatched:
 
190
            if self._finished:
 
191
                # nothing to do.XXX: this routine should be a single state
 
192
                # machine too.
 
193
                self.unused_data += self.in_buffer
 
194
                self.in_buffer = b''
 
195
                return
 
196
            if self._body_decoder is None:
 
197
                self._body_decoder = LengthPrefixedBodyDecoder()
 
198
            self._body_decoder.accept_bytes(self.in_buffer)
 
199
            self.in_buffer = self._body_decoder.unused_data
 
200
            body_data = self._body_decoder.read_pending_data()
 
201
            self.request.accept_body(body_data)
 
202
            if self._body_decoder.finished_reading:
 
203
                self.request.end_of_body()
 
204
                if not self.request.finished_reading:
 
205
                    raise AssertionError("no more body, request not finished")
 
206
            if self.request.response is not None:
 
207
                self._send_response(self.request.response)
 
208
                self.unused_data = self.in_buffer
 
209
                self.in_buffer = b''
 
210
            else:
 
211
                if self.request.finished_reading:
 
212
                    raise AssertionError(
 
213
                        "no response and we have finished reading.")
 
214
 
 
215
    def _send_response(self, response):
 
216
        """Send a smart server response down the output stream."""
 
217
        if self._finished:
 
218
            raise AssertionError('response already sent')
 
219
        args = response.args
 
220
        body = response.body
 
221
        self._finished = True
 
222
        self._write_protocol_version()
 
223
        self._write_success_or_failure_prefix(response)
 
224
        self._write_func(_encode_tuple(args))
 
225
        if body is not None:
 
226
            if not isinstance(body, str):
 
227
                raise ValueError(body)
 
228
            bytes = self._encode_bulk_data(body)
 
229
            self._write_func(bytes)
 
230
 
 
231
    def _write_protocol_version(self):
 
232
        """Write any prefixes this protocol requires.
 
233
 
 
234
        Version one doesn't send protocol versions.
 
235
        """
 
236
 
 
237
    def _write_success_or_failure_prefix(self, response):
 
238
        """Write the protocol specific success/failure prefix.
 
239
 
 
240
        For SmartServerRequestProtocolOne this is omitted but we
 
241
        call is_successful to ensure that the response is valid.
 
242
        """
 
243
        response.is_successful()
 
244
 
 
245
    def next_read_size(self):
 
246
        if self._finished:
 
247
            return 0
 
248
        if self._body_decoder is None:
 
249
            return 1
 
250
        else:
 
251
            return self._body_decoder.next_read_size()
 
252
 
 
253
 
 
254
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
 
255
    r"""Version two of the server side of the smart protocol.
 
256
 
 
257
    This prefixes responses with the value of RESPONSE_VERSION_TWO.
 
258
    """
 
259
 
 
260
    response_marker = RESPONSE_VERSION_TWO
 
261
    request_marker = REQUEST_VERSION_TWO
 
262
 
 
263
    def _write_success_or_failure_prefix(self, response):
 
264
        """Write the protocol specific success/failure prefix."""
 
265
        if response.is_successful():
 
266
            self._write_func(b'success\n')
 
267
        else:
 
268
            self._write_func(b'failed\n')
 
269
 
 
270
    def _write_protocol_version(self):
 
271
        r"""Write any prefixes this protocol requires.
 
272
 
 
273
        Version two sends the value of RESPONSE_VERSION_TWO.
 
274
        """
 
275
        self._write_func(self.response_marker)
 
276
 
 
277
    def _send_response(self, response):
 
278
        """Send a smart server response down the output stream."""
 
279
        if (self._finished):
 
280
            raise AssertionError('response already sent')
 
281
        self._finished = True
 
282
        self._write_protocol_version()
 
283
        self._write_success_or_failure_prefix(response)
 
284
        self._write_func(_encode_tuple(response.args))
 
285
        if response.body is not None:
 
286
            if not isinstance(response.body, bytes):
 
287
                raise AssertionError('body must be bytes')
 
288
            if not (response.body_stream is None):
 
289
                raise AssertionError(
 
290
                    'body_stream and body cannot both be set')
 
291
            data = self._encode_bulk_data(response.body)
 
292
            self._write_func(data)
 
293
        elif response.body_stream is not None:
 
294
            _send_stream(response.body_stream, self._write_func)
 
295
 
 
296
 
 
297
def _send_stream(stream, write_func):
 
298
    write_func(b'chunked\n')
 
299
    _send_chunks(stream, write_func)
 
300
    write_func(b'END\n')
 
301
 
 
302
 
 
303
def _send_chunks(stream, write_func):
 
304
    for chunk in stream:
 
305
        if isinstance(chunk, bytes):
 
306
            data = ("%x\n" % len(chunk)).encode('ascii') + chunk
 
307
            write_func(data)
 
308
        elif isinstance(chunk, request.FailedSmartServerResponse):
 
309
            write_func(b'ERR\n')
 
310
            _send_chunks(chunk.args, write_func)
 
311
            return
 
312
        else:
 
313
            raise errors.BzrError(
 
314
                'Chunks must be str or FailedSmartServerResponse, got %r'
 
315
                % chunk)
 
316
 
 
317
 
 
318
class _NeedMoreBytes(Exception):
 
319
    """Raise this inside a _StatefulDecoder to stop decoding until more bytes
 
320
    have been received.
 
321
    """
 
322
 
 
323
    def __init__(self, count=None):
 
324
        """Constructor.
 
325
 
 
326
        :param count: the total number of bytes needed by the current state.
 
327
            May be None if the number of bytes needed is unknown.
 
328
        """
 
329
        self.count = count
 
330
 
 
331
 
 
332
class _StatefulDecoder(object):
 
333
    """Base class for writing state machines to decode byte streams.
 
334
 
 
335
    Subclasses should provide a self.state_accept attribute that accepts bytes
 
336
    and, if appropriate, updates self.state_accept to a different function.
 
337
    accept_bytes will call state_accept as often as necessary to make sure the
 
338
    state machine has progressed as far as possible before it returns.
 
339
 
 
340
    See ProtocolThreeDecoder for an example subclass.
 
341
    """
 
342
 
 
343
    def __init__(self):
 
344
        self.finished_reading = False
 
345
        self._in_buffer_list = []
 
346
        self._in_buffer_len = 0
 
347
        self.unused_data = b''
 
348
        self.bytes_left = None
 
349
        self._number_needed_bytes = None
 
350
 
 
351
    def _get_in_buffer(self):
 
352
        if len(self._in_buffer_list) == 1:
 
353
            return self._in_buffer_list[0]
 
354
        in_buffer = b''.join(self._in_buffer_list)
 
355
        if len(in_buffer) != self._in_buffer_len:
 
356
            raise AssertionError(
 
357
                "Length of buffer did not match expected value: %s != %s"
 
358
                % self._in_buffer_len, len(in_buffer))
 
359
        self._in_buffer_list = [in_buffer]
 
360
        return in_buffer
 
361
 
 
362
    def _get_in_bytes(self, count):
 
363
        """Grab X bytes from the input_buffer.
 
364
 
 
365
        Callers should have already checked that self._in_buffer_len is >
 
366
        count. Note, this does not consume the bytes from the buffer. The
 
367
        caller will still need to call _get_in_buffer() and then
 
368
        _set_in_buffer() if they actually need to consume the bytes.
 
369
        """
 
370
        # check if we can yield the bytes from just the first entry in our list
 
371
        if len(self._in_buffer_list) == 0:
 
372
            raise AssertionError('Callers must be sure we have buffered bytes'
 
373
                ' before calling _get_in_bytes')
 
374
        if len(self._in_buffer_list[0]) > count:
 
375
            return self._in_buffer_list[0][:count]
 
376
        # We can't yield it from the first buffer, so collapse all buffers, and
 
377
        # yield it from that
 
378
        in_buf = self._get_in_buffer()
 
379
        return in_buf[:count]
 
380
 
 
381
    def _set_in_buffer(self, new_buf):
 
382
        if new_buf is not None:
 
383
            self._in_buffer_list = [new_buf]
 
384
            self._in_buffer_len = len(new_buf)
 
385
        else:
 
386
            self._in_buffer_list = []
 
387
            self._in_buffer_len = 0
 
388
 
 
389
    def accept_bytes(self, bytes):
 
390
        """Decode as much of bytes as possible.
 
391
 
 
392
        If 'bytes' contains too much data it will be appended to
 
393
        self.unused_data.
 
394
 
 
395
        finished_reading will be set when no more data is required.  Further
 
396
        data will be appended to self.unused_data.
 
397
        """
 
398
        # accept_bytes is allowed to change the state
 
399
        self._number_needed_bytes = None
 
400
        # lsprof puts a very large amount of time on this specific call for
 
401
        # large readv arrays
 
402
        self._in_buffer_list.append(bytes)
 
403
        self._in_buffer_len += len(bytes)
 
404
        try:
 
405
            # Run the function for the current state.
 
406
            current_state = self.state_accept
 
407
            self.state_accept()
 
408
            while current_state != self.state_accept:
 
409
                # The current state has changed.  Run the function for the new
 
410
                # current state, so that it can:
 
411
                #   - decode any unconsumed bytes left in a buffer, and
 
412
                #   - signal how many more bytes are expected (via raising
 
413
                #     _NeedMoreBytes).
 
414
                current_state = self.state_accept
 
415
                self.state_accept()
 
416
        except _NeedMoreBytes as e:
 
417
            self._number_needed_bytes = e.count
 
418
 
 
419
 
 
420
class ChunkedBodyDecoder(_StatefulDecoder):
 
421
    """Decoder for chunked body data.
 
422
 
 
423
    This is very similar the HTTP's chunked encoding.  See the description of
 
424
    streamed body data in `doc/developers/network-protocol.txt` for details.
 
425
    """
 
426
 
 
427
    def __init__(self):
 
428
        _StatefulDecoder.__init__(self)
 
429
        self.state_accept = self._state_accept_expecting_header
 
430
        self.chunk_in_progress = None
 
431
        self.chunks = collections.deque()
 
432
        self.error = False
 
433
        self.error_in_progress = None
 
434
 
 
435
    def next_read_size(self):
 
436
        # Note: the shortest possible chunk is 2 bytes: '0\n', and the
 
437
        # end-of-body marker is 4 bytes: 'END\n'.
 
438
        if self.state_accept == self._state_accept_reading_chunk:
 
439
            # We're expecting more chunk content.  So we're expecting at least
 
440
            # the rest of this chunk plus an END chunk.
 
441
            return self.bytes_left + 4
 
442
        elif self.state_accept == self._state_accept_expecting_length:
 
443
            if self._in_buffer_len == 0:
 
444
                # We're expecting a chunk length.  There's at least two bytes
 
445
                # left: a digit plus '\n'.
 
446
                return 2
 
447
            else:
 
448
                # We're in the middle of reading a chunk length.  So there's at
 
449
                # least one byte left, the '\n' that terminates the length.
 
450
                return 1
 
451
        elif self.state_accept == self._state_accept_reading_unused:
 
452
            return 1
 
453
        elif self.state_accept == self._state_accept_expecting_header:
 
454
            return max(0, len('chunked\n') - self._in_buffer_len)
 
455
        else:
 
456
            raise AssertionError("Impossible state: %r" % (self.state_accept,))
 
457
 
 
458
    def read_next_chunk(self):
 
459
        try:
 
460
            return self.chunks.popleft()
 
461
        except IndexError:
 
462
            return None
 
463
 
 
464
    def _extract_line(self):
 
465
        in_buf = self._get_in_buffer()
 
466
        pos = in_buf.find(b'\n')
 
467
        if pos == -1:
 
468
            # We haven't read a complete line yet, so request more bytes before
 
469
            # we continue.
 
470
            raise _NeedMoreBytes(1)
 
471
        line = in_buf[:pos]
 
472
        # Trim the prefix (including '\n' delimiter) from the _in_buffer.
 
473
        self._set_in_buffer(in_buf[pos+1:])
 
474
        return line
 
475
 
 
476
    def _finished(self):
 
477
        self.unused_data = self._get_in_buffer()
 
478
        self._in_buffer_list = []
 
479
        self._in_buffer_len = 0
 
480
        self.state_accept = self._state_accept_reading_unused
 
481
        if self.error:
 
482
            error_args = tuple(self.error_in_progress)
 
483
            self.chunks.append(request.FailedSmartServerResponse(error_args))
 
484
            self.error_in_progress = None
 
485
        self.finished_reading = True
 
486
 
 
487
    def _state_accept_expecting_header(self):
 
488
        prefix = self._extract_line()
 
489
        if prefix == b'chunked':
 
490
            self.state_accept = self._state_accept_expecting_length
 
491
        else:
 
492
            raise errors.SmartProtocolError(
 
493
                'Bad chunked body header: "%s"' % (prefix,))
 
494
 
 
495
    def _state_accept_expecting_length(self):
 
496
        prefix = self._extract_line()
 
497
        if prefix == b'ERR':
 
498
            self.error = True
 
499
            self.error_in_progress = []
 
500
            self._state_accept_expecting_length()
 
501
            return
 
502
        elif prefix == b'END':
 
503
            # We've read the end-of-body marker.
 
504
            # Any further bytes are unused data, including the bytes left in
 
505
            # the _in_buffer.
 
506
            self._finished()
 
507
            return
 
508
        else:
 
509
            self.bytes_left = int(prefix, 16)
 
510
            self.chunk_in_progress = b''
 
511
            self.state_accept = self._state_accept_reading_chunk
 
512
 
 
513
    def _state_accept_reading_chunk(self):
 
514
        in_buf = self._get_in_buffer()
 
515
        in_buffer_len = len(in_buf)
 
516
        self.chunk_in_progress += in_buf[:self.bytes_left]
 
517
        self._set_in_buffer(in_buf[self.bytes_left:])
 
518
        self.bytes_left -= in_buffer_len
 
519
        if self.bytes_left <= 0:
 
520
            # Finished with chunk
 
521
            self.bytes_left = None
 
522
            if self.error:
 
523
                self.error_in_progress.append(self.chunk_in_progress)
 
524
            else:
 
525
                self.chunks.append(self.chunk_in_progress)
 
526
            self.chunk_in_progress = None
 
527
            self.state_accept = self._state_accept_expecting_length
 
528
 
 
529
    def _state_accept_reading_unused(self):
 
530
        self.unused_data += self._get_in_buffer()
 
531
        self._in_buffer_list = []
 
532
 
 
533
 
 
534
class LengthPrefixedBodyDecoder(_StatefulDecoder):
 
535
    """Decodes the length-prefixed bulk data."""
 
536
 
 
537
    def __init__(self):
 
538
        _StatefulDecoder.__init__(self)
 
539
        self.state_accept = self._state_accept_expecting_length
 
540
        self.state_read = self._state_read_no_data
 
541
        self._body = b''
 
542
        self._trailer_buffer = b''
 
543
 
 
544
    def next_read_size(self):
 
545
        if self.bytes_left is not None:
 
546
            # Ideally we want to read all the remainder of the body and the
 
547
            # trailer in one go.
 
548
            return self.bytes_left + 5
 
549
        elif self.state_accept == self._state_accept_reading_trailer:
 
550
            # Just the trailer left
 
551
            return 5 - len(self._trailer_buffer)
 
552
        elif self.state_accept == self._state_accept_expecting_length:
 
553
            # There's still at least 6 bytes left ('\n' to end the length, plus
 
554
            # 'done\n').
 
555
            return 6
 
556
        else:
 
557
            # Reading excess data.  Either way, 1 byte at a time is fine.
 
558
            return 1
 
559
 
 
560
    def read_pending_data(self):
 
561
        """Return any pending data that has been decoded."""
 
562
        return self.state_read()
 
563
 
 
564
    def _state_accept_expecting_length(self):
 
565
        in_buf = self._get_in_buffer()
 
566
        pos = in_buf.find(b'\n')
 
567
        if pos == -1:
 
568
            return
 
569
        self.bytes_left = int(in_buf[:pos])
 
570
        self._set_in_buffer(in_buf[pos+1:])
 
571
        self.state_accept = self._state_accept_reading_body
 
572
        self.state_read = self._state_read_body_buffer
 
573
 
 
574
    def _state_accept_reading_body(self):
 
575
        in_buf = self._get_in_buffer()
 
576
        self._body += in_buf
 
577
        self.bytes_left -= len(in_buf)
 
578
        self._set_in_buffer(None)
 
579
        if self.bytes_left <= 0:
 
580
            # Finished with body
 
581
            if self.bytes_left != 0:
 
582
                self._trailer_buffer = self._body[self.bytes_left:]
 
583
                self._body = self._body[:self.bytes_left]
 
584
            self.bytes_left = None
 
585
            self.state_accept = self._state_accept_reading_trailer
 
586
 
 
587
    def _state_accept_reading_trailer(self):
 
588
        self._trailer_buffer += self._get_in_buffer()
 
589
        self._set_in_buffer(None)
 
590
        # TODO: what if the trailer does not match "done\n"?  Should this raise
 
591
        # a ProtocolViolation exception?
 
592
        if self._trailer_buffer.startswith(b'done\n'):
 
593
            self.unused_data = self._trailer_buffer[len(b'done\n'):]
 
594
            self.state_accept = self._state_accept_reading_unused
 
595
            self.finished_reading = True
 
596
 
 
597
    def _state_accept_reading_unused(self):
 
598
        self.unused_data += self._get_in_buffer()
 
599
        self._set_in_buffer(None)
 
600
 
 
601
    def _state_read_no_data(self):
 
602
        return b''
 
603
 
 
604
    def _state_read_body_buffer(self):
 
605
        result = self._body
 
606
        self._body = b''
 
607
        return result
 
608
 
 
609
 
 
610
class SmartClientRequestProtocolOne(SmartProtocolBase, Requester,
 
611
                                    message.ResponseHandler):
 
612
    """The client-side protocol for smart version 1."""
 
613
 
 
614
    def __init__(self, request):
 
615
        """Construct a SmartClientRequestProtocolOne.
 
616
 
 
617
        :param request: A SmartClientMediumRequest to serialise onto and
 
618
            deserialise from.
 
619
        """
 
620
        self._request = request
 
621
        self._body_buffer = None
 
622
        self._request_start_time = None
 
623
        self._last_verb = None
 
624
        self._headers = None
 
625
 
 
626
    def set_headers(self, headers):
 
627
        self._headers = dict(headers)
 
628
 
 
629
    def call(self, *args):
 
630
        if 'hpss' in debug.debug_flags:
 
631
            mutter('hpss call:   %s', repr(args)[1:-1])
 
632
            if getattr(self._request._medium, 'base', None) is not None:
 
633
                mutter('             (to %s)', self._request._medium.base)
 
634
            self._request_start_time = osutils.timer_func()
 
635
        self._write_args(args)
 
636
        self._request.finished_writing()
 
637
        self._last_verb = args[0]
 
638
 
 
639
    def call_with_body_bytes(self, args, body):
 
640
        """Make a remote call of args with body bytes 'body'.
 
641
 
 
642
        After calling this, call read_response_tuple to find the result out.
 
643
        """
 
644
        if 'hpss' in debug.debug_flags:
 
645
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
 
646
            if getattr(self._request._medium, '_path', None) is not None:
 
647
                mutter('                  (to %s)', self._request._medium._path)
 
648
            mutter('              %d bytes', len(body))
 
649
            self._request_start_time = osutils.timer_func()
 
650
            if 'hpssdetail' in debug.debug_flags:
 
651
                mutter('hpss body content: %s', body)
 
652
        self._write_args(args)
 
653
        bytes = self._encode_bulk_data(body)
 
654
        self._request.accept_bytes(bytes)
 
655
        self._request.finished_writing()
 
656
        self._last_verb = args[0]
 
657
 
 
658
    def call_with_body_readv_array(self, args, body):
 
659
        """Make a remote call with a readv array.
 
660
 
 
661
        The body is encoded with one line per readv offset pair. The numbers in
 
662
        each pair are separated by a comma, and no trailing \\n is emitted.
 
663
        """
 
664
        if 'hpss' in debug.debug_flags:
 
665
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
 
666
            if getattr(self._request._medium, '_path', None) is not None:
 
667
                mutter('                  (to %s)', self._request._medium._path)
 
668
            self._request_start_time = osutils.timer_func()
 
669
        self._write_args(args)
 
670
        readv_bytes = self._serialise_offsets(body)
 
671
        bytes = self._encode_bulk_data(readv_bytes)
 
672
        self._request.accept_bytes(bytes)
 
673
        self._request.finished_writing()
 
674
        if 'hpss' in debug.debug_flags:
 
675
            mutter('              %d bytes in readv request', len(readv_bytes))
 
676
        self._last_verb = args[0]
 
677
 
 
678
    def call_with_body_stream(self, args, stream):
 
679
        # Protocols v1 and v2 don't support body streams.  So it's safe to
 
680
        # assume that a v1/v2 server doesn't support whatever method we're
 
681
        # trying to call with a body stream.
 
682
        self._request.finished_writing()
 
683
        self._request.finished_reading()
 
684
        raise errors.UnknownSmartMethod(args[0])
 
685
 
 
686
    def cancel_read_body(self):
 
687
        """After expecting a body, a response code may indicate one otherwise.
 
688
 
 
689
        This method lets the domain client inform the protocol that no body
 
690
        will be transmitted. This is a terminal method: after calling it the
 
691
        protocol is not able to be used further.
 
692
        """
 
693
        self._request.finished_reading()
 
694
 
 
695
    def _read_response_tuple(self):
 
696
        result = self._recv_tuple()
 
697
        if 'hpss' in debug.debug_flags:
 
698
            if self._request_start_time is not None:
 
699
                mutter('   result:   %6.3fs  %s',
 
700
                       osutils.timer_func() - self._request_start_time,
 
701
                       repr(result)[1:-1])
 
702
                self._request_start_time = None
 
703
            else:
 
704
                mutter('   result:   %s', repr(result)[1:-1])
 
705
        return result
 
706
 
 
707
    def read_response_tuple(self, expect_body=False):
 
708
        """Read a response tuple from the wire.
 
709
 
 
710
        This should only be called once.
 
711
        """
 
712
        result = self._read_response_tuple()
 
713
        self._response_is_unknown_method(result)
 
714
        self._raise_args_if_error(result)
 
715
        if not expect_body:
 
716
            self._request.finished_reading()
 
717
        return result
 
718
 
 
719
    def _raise_args_if_error(self, result_tuple):
 
720
        # Later protocol versions have an explicit flag in the protocol to say
 
721
        # if an error response is "failed" or not.  In version 1 we don't have
 
722
        # that luxury.  So here is a complete list of errors that can be
 
723
        # returned in response to existing version 1 smart requests.  Responses
 
724
        # starting with these codes are always "failed" responses.
 
725
        v1_error_codes = [
 
726
            b'norepository',
 
727
            b'NoSuchFile',
 
728
            b'FileExists',
 
729
            b'DirectoryNotEmpty',
 
730
            b'ShortReadvError',
 
731
            b'UnicodeEncodeError',
 
732
            b'UnicodeDecodeError',
 
733
            b'ReadOnlyError',
 
734
            b'nobranch',
 
735
            b'NoSuchRevision',
 
736
            b'nosuchrevision',
 
737
            b'LockContention',
 
738
            b'UnlockableTransport',
 
739
            b'LockFailed',
 
740
            b'TokenMismatch',
 
741
            b'ReadError',
 
742
            b'PermissionDenied',
 
743
            ]
 
744
        if result_tuple[0] in v1_error_codes:
 
745
            self._request.finished_reading()
 
746
            raise errors.ErrorFromSmartServer(result_tuple)
 
747
 
 
748
    def _response_is_unknown_method(self, result_tuple):
 
749
        """Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
 
750
        method' response to the request.
 
751
 
 
752
        :param response: The response from a smart client call_expecting_body
 
753
            call.
 
754
        :param verb: The verb used in that call.
 
755
        :raises: UnexpectedSmartServerResponse
 
756
        """
 
757
        if (result_tuple == (b'error', b"Generic bzr smart protocol error: "
 
758
                b"bad request '" + self._last_verb + b"'") or
 
759
              result_tuple == (b'error', b"Generic bzr smart protocol error: "
 
760
                b"bad request u'%s'" % self._last_verb)):
 
761
            # The response will have no body, so we've finished reading.
 
762
            self._request.finished_reading()
 
763
            raise errors.UnknownSmartMethod(self._last_verb)
 
764
 
 
765
    def read_body_bytes(self, count=-1):
 
766
        """Read bytes from the body, decoding into a byte stream.
 
767
 
 
768
        We read all bytes at once to ensure we've checked the trailer for
 
769
        errors, and then feed the buffer back as read_body_bytes is called.
 
770
        """
 
771
        if self._body_buffer is not None:
 
772
            return self._body_buffer.read(count)
 
773
        _body_decoder = LengthPrefixedBodyDecoder()
 
774
 
 
775
        while not _body_decoder.finished_reading:
 
776
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
777
            if bytes == b'':
 
778
                # end of file encountered reading from server
 
779
                raise errors.ConnectionReset(
 
780
                    "Connection lost while reading response body.")
 
781
            _body_decoder.accept_bytes(bytes)
 
782
        self._request.finished_reading()
 
783
        self._body_buffer = BytesIO(_body_decoder.read_pending_data())
 
784
        # XXX: TODO check the trailer result.
 
785
        if 'hpss' in debug.debug_flags:
 
786
            mutter('              %d body bytes read',
 
787
                   len(self._body_buffer.getvalue()))
 
788
        return self._body_buffer.read(count)
 
789
 
 
790
    def _recv_tuple(self):
 
791
        """Receive a tuple from the medium request."""
 
792
        return _decode_tuple(self._request.read_line())
 
793
 
 
794
    def query_version(self):
 
795
        """Return protocol version number of the server."""
 
796
        self.call(b'hello')
 
797
        resp = self.read_response_tuple()
 
798
        if resp == (b'ok', '1'):
 
799
            return 1
 
800
        elif resp == (b'ok', '2'):
 
801
            return 2
 
802
        else:
 
803
            raise errors.SmartProtocolError("bad response %r" % (resp,))
 
804
 
 
805
    def _write_args(self, args):
 
806
        self._write_protocol_version()
 
807
        bytes = _encode_tuple(args)
 
808
        self._request.accept_bytes(bytes)
 
809
 
 
810
    def _write_protocol_version(self):
 
811
        """Write any prefixes this protocol requires.
 
812
 
 
813
        Version one doesn't send protocol versions.
 
814
        """
 
815
 
 
816
 
 
817
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
 
818
    """Version two of the client side of the smart protocol.
 
819
 
 
820
    This prefixes the request with the value of REQUEST_VERSION_TWO.
 
821
    """
 
822
 
 
823
    response_marker = RESPONSE_VERSION_TWO
 
824
    request_marker = REQUEST_VERSION_TWO
 
825
 
 
826
    def read_response_tuple(self, expect_body=False):
 
827
        """Read a response tuple from the wire.
 
828
 
 
829
        This should only be called once.
 
830
        """
 
831
        version = self._request.read_line()
 
832
        if version != self.response_marker:
 
833
            self._request.finished_reading()
 
834
            raise errors.UnexpectedProtocolVersionMarker(version)
 
835
        response_status = self._request.read_line()
 
836
        result = SmartClientRequestProtocolOne._read_response_tuple(self)
 
837
        self._response_is_unknown_method(result)
 
838
        if response_status == b'success\n':
 
839
            self.response_status = True
 
840
            if not expect_body:
 
841
                self._request.finished_reading()
 
842
            return result
 
843
        elif response_status == b'failed\n':
 
844
            self.response_status = False
 
845
            self._request.finished_reading()
 
846
            raise errors.ErrorFromSmartServer(result)
 
847
        else:
 
848
            raise errors.SmartProtocolError(
 
849
                'bad protocol status %r' % response_status)
 
850
 
 
851
    def _write_protocol_version(self):
 
852
        """Write any prefixes this protocol requires.
 
853
 
 
854
        Version two sends the value of REQUEST_VERSION_TWO.
 
855
        """
 
856
        self._request.accept_bytes(self.request_marker)
 
857
 
 
858
    def read_streamed_body(self):
 
859
        """Read bytes from the body, decoding into a byte stream.
 
860
        """
 
861
        # Read no more than 64k at a time so that we don't risk error 10055 (no
 
862
        # buffer space available) on Windows.
 
863
        _body_decoder = ChunkedBodyDecoder()
 
864
        while not _body_decoder.finished_reading:
 
865
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
866
            if bytes == b'':
 
867
                # end of file encountered reading from server
 
868
                raise errors.ConnectionReset(
 
869
                    "Connection lost while reading streamed body.")
 
870
            _body_decoder.accept_bytes(bytes)
 
871
            for body_bytes in iter(_body_decoder.read_next_chunk, None):
 
872
                if 'hpss' in debug.debug_flags and isinstance(body_bytes, str):
 
873
                    mutter('              %d byte chunk read',
 
874
                           len(body_bytes))
 
875
                yield body_bytes
 
876
        self._request.finished_reading()
 
877
 
 
878
 
 
879
def build_server_protocol_three(backing_transport, write_func,
 
880
                                root_client_path, jail_root=None):
 
881
    request_handler = request.SmartServerRequestHandler(
 
882
        backing_transport, commands=request.request_handlers,
 
883
        root_client_path=root_client_path, jail_root=jail_root)
 
884
    responder = ProtocolThreeResponder(write_func)
 
885
    message_handler = message.ConventionalRequestHandler(request_handler, responder)
 
886
    return ProtocolThreeDecoder(message_handler)
 
887
 
 
888
 
 
889
class ProtocolThreeDecoder(_StatefulDecoder):
 
890
 
 
891
    response_marker = RESPONSE_VERSION_THREE
 
892
    request_marker = REQUEST_VERSION_THREE
 
893
 
 
894
    def __init__(self, message_handler, expect_version_marker=False):
 
895
        _StatefulDecoder.__init__(self)
 
896
        self._has_dispatched = False
 
897
        # Initial state
 
898
        if expect_version_marker:
 
899
            self.state_accept = self._state_accept_expecting_protocol_version
 
900
            # We're expecting at least the protocol version marker + some
 
901
            # headers.
 
902
            self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
 
903
        else:
 
904
            self.state_accept = self._state_accept_expecting_headers
 
905
            self._number_needed_bytes = 4
 
906
        self.decoding_failed = False
 
907
        self.request_handler = self.message_handler = message_handler
 
908
 
 
909
    def accept_bytes(self, bytes):
 
910
        self._number_needed_bytes = None
 
911
        try:
 
912
            _StatefulDecoder.accept_bytes(self, bytes)
 
913
        except KeyboardInterrupt:
 
914
            raise
 
915
        except errors.SmartMessageHandlerError as exception:
 
916
            # We do *not* set self.decoding_failed here.  The message handler
 
917
            # has raised an error, but the decoder is still able to parse bytes
 
918
            # and determine when this message ends.
 
919
            if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
 
920
                log_exception_quietly()
 
921
            self.message_handler.protocol_error(exception.exc_value)
 
922
            # The state machine is ready to continue decoding, but the
 
923
            # exception has interrupted the loop that runs the state machine.
 
924
            # So we call accept_bytes again to restart it.
 
925
            self.accept_bytes('')
 
926
        except Exception as exception:
 
927
            # The decoder itself has raised an exception.  We cannot continue
 
928
            # decoding.
 
929
            self.decoding_failed = True
 
930
            if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
 
931
                # This happens during normal operation when the client tries a
 
932
                # protocol version the server doesn't understand, so no need to
 
933
                # log a traceback every time.
 
934
                # Note that this can only happen when
 
935
                # expect_version_marker=True, which is only the case on the
 
936
                # client side.
 
937
                pass
 
938
            else:
 
939
                log_exception_quietly()
 
940
            self.message_handler.protocol_error(exception)
 
941
 
 
942
    def _extract_length_prefixed_bytes(self):
 
943
        if self._in_buffer_len < 4:
 
944
            # A length prefix by itself is 4 bytes, and we don't even have that
 
945
            # many yet.
 
946
            raise _NeedMoreBytes(4)
 
947
        (length,) = struct.unpack('!L', self._get_in_bytes(4))
 
948
        end_of_bytes = 4 + length
 
949
        if self._in_buffer_len < end_of_bytes:
 
950
            # We haven't yet read as many bytes as the length-prefix says there
 
951
            # are.
 
952
            raise _NeedMoreBytes(end_of_bytes)
 
953
        # Extract the bytes from the buffer.
 
954
        in_buf = self._get_in_buffer()
 
955
        bytes = in_buf[4:end_of_bytes]
 
956
        self._set_in_buffer(in_buf[end_of_bytes:])
 
957
        return bytes
 
958
 
 
959
    def _extract_prefixed_bencoded_data(self):
 
960
        prefixed_bytes = self._extract_length_prefixed_bytes()
 
961
        try:
 
962
            decoded = bdecode_as_tuple(prefixed_bytes)
 
963
        except ValueError:
 
964
            raise errors.SmartProtocolError(
 
965
                'Bytes %r not bencoded' % (prefixed_bytes,))
 
966
        return decoded
 
967
 
 
968
    def _extract_single_byte(self):
 
969
        if self._in_buffer_len == 0:
 
970
            # The buffer is empty
 
971
            raise _NeedMoreBytes(1)
 
972
        in_buf = self._get_in_buffer()
 
973
        one_byte = in_buf[0:1]
 
974
        self._set_in_buffer(in_buf[1:])
 
975
        return one_byte
 
976
 
 
977
    def _state_accept_expecting_protocol_version(self):
 
978
        needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
 
979
        in_buf = self._get_in_buffer()
 
980
        if needed_bytes > 0:
 
981
            # We don't have enough bytes to check if the protocol version
 
982
            # marker is right.  But we can check if it is already wrong by
 
983
            # checking that the start of MESSAGE_VERSION_THREE matches what
 
984
            # we've read so far.
 
985
            # [In fact, if the remote end isn't bzr we might never receive
 
986
            # len(MESSAGE_VERSION_THREE) bytes.  So if the bytes we have so far
 
987
            # are wrong then we should just raise immediately rather than
 
988
            # stall.]
 
989
            if not MESSAGE_VERSION_THREE.startswith(in_buf):
 
990
                # We have enough bytes to know the protocol version is wrong
 
991
                raise errors.UnexpectedProtocolVersionMarker(in_buf)
 
992
            raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
 
993
        if not in_buf.startswith(MESSAGE_VERSION_THREE):
 
994
            raise errors.UnexpectedProtocolVersionMarker(in_buf)
 
995
        self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
 
996
        self.state_accept = self._state_accept_expecting_headers
 
997
 
 
998
    def _state_accept_expecting_headers(self):
 
999
        decoded = self._extract_prefixed_bencoded_data()
 
1000
        if not isinstance(decoded, dict):
 
1001
            raise errors.SmartProtocolError(
 
1002
                'Header object %r is not a dict' % (decoded,))
 
1003
        self.state_accept = self._state_accept_expecting_message_part
 
1004
        try:
 
1005
            self.message_handler.headers_received(decoded)
 
1006
        except:
 
1007
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1008
 
 
1009
    def _state_accept_expecting_message_part(self):
 
1010
        message_part_kind = self._extract_single_byte()
 
1011
        if message_part_kind == b'o':
 
1012
            self.state_accept = self._state_accept_expecting_one_byte
 
1013
        elif message_part_kind == b's':
 
1014
            self.state_accept = self._state_accept_expecting_structure
 
1015
        elif message_part_kind == b'b':
 
1016
            self.state_accept = self._state_accept_expecting_bytes
 
1017
        elif message_part_kind == b'e':
 
1018
            self.done()
 
1019
        else:
 
1020
            raise errors.SmartProtocolError(
 
1021
                'Bad message kind byte: %r' % (message_part_kind,))
 
1022
 
 
1023
    def _state_accept_expecting_one_byte(self):
 
1024
        byte = self._extract_single_byte()
 
1025
        self.state_accept = self._state_accept_expecting_message_part
 
1026
        try:
 
1027
            self.message_handler.byte_part_received(byte)
 
1028
        except:
 
1029
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1030
 
 
1031
    def _state_accept_expecting_bytes(self):
 
1032
        # XXX: this should not buffer whole message part, but instead deliver
 
1033
        # the bytes as they arrive.
 
1034
        prefixed_bytes = self._extract_length_prefixed_bytes()
 
1035
        self.state_accept = self._state_accept_expecting_message_part
 
1036
        try:
 
1037
            self.message_handler.bytes_part_received(prefixed_bytes)
 
1038
        except:
 
1039
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1040
 
 
1041
    def _state_accept_expecting_structure(self):
 
1042
        structure = self._extract_prefixed_bencoded_data()
 
1043
        self.state_accept = self._state_accept_expecting_message_part
 
1044
        try:
 
1045
            self.message_handler.structure_part_received(structure)
 
1046
        except:
 
1047
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1048
 
 
1049
    def done(self):
 
1050
        self.unused_data = self._get_in_buffer()
 
1051
        self._set_in_buffer(None)
 
1052
        self.state_accept = self._state_accept_reading_unused
 
1053
        try:
 
1054
            self.message_handler.end_received()
 
1055
        except:
 
1056
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1057
 
 
1058
    def _state_accept_reading_unused(self):
 
1059
        self.unused_data += self._get_in_buffer()
 
1060
        self._set_in_buffer(None)
 
1061
 
 
1062
    def next_read_size(self):
 
1063
        if self.state_accept == self._state_accept_reading_unused:
 
1064
            return 0
 
1065
        elif self.decoding_failed:
 
1066
            # An exception occured while processing this message, probably from
 
1067
            # self.message_handler.  We're not sure that this state machine is
 
1068
            # in a consistent state, so just signal that we're done (i.e. give
 
1069
            # up).
 
1070
            return 0
 
1071
        else:
 
1072
            if self._number_needed_bytes is not None:
 
1073
                return self._number_needed_bytes - self._in_buffer_len
 
1074
            else:
 
1075
                raise AssertionError("don't know how many bytes are expected!")
 
1076
 
 
1077
 
 
1078
class _ProtocolThreeEncoder(object):
 
1079
 
 
1080
    response_marker = request_marker = MESSAGE_VERSION_THREE
 
1081
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
 
1082
 
 
1083
    def __init__(self, write_func):
 
1084
        self._buf = []
 
1085
        self._buf_len = 0
 
1086
        self._real_write_func = write_func
 
1087
 
 
1088
    def _write_func(self, bytes):
 
1089
        # TODO: Another possibility would be to turn this into an async model.
 
1090
        #       Where we let another thread know that we have some bytes if
 
1091
        #       they want it, but we don't actually block for it
 
1092
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
 
1093
        #       we might just push out smaller bits at a time?
 
1094
        self._buf.append(bytes)
 
1095
        self._buf_len += len(bytes)
 
1096
        if self._buf_len > self.BUFFER_SIZE:
 
1097
            self.flush()
 
1098
 
 
1099
    def flush(self):
 
1100
        if self._buf:
 
1101
            self._real_write_func(b''.join(self._buf))
 
1102
            del self._buf[:]
 
1103
            self._buf_len = 0
 
1104
 
 
1105
    def _serialise_offsets(self, offsets):
 
1106
        """Serialise a readv offset list."""
 
1107
        txt = []
 
1108
        for start, length in offsets:
 
1109
            txt.append(b'%d,%d' % (start, length))
 
1110
        return b'\n'.join(txt)
 
1111
 
 
1112
    def _write_protocol_version(self):
 
1113
        self._write_func(MESSAGE_VERSION_THREE)
 
1114
 
 
1115
    def _write_prefixed_bencode(self, structure):
 
1116
        bytes = bencode(structure)
 
1117
        self._write_func(struct.pack('!L', len(bytes)))
 
1118
        self._write_func(bytes)
 
1119
 
 
1120
    def _write_headers(self, headers):
 
1121
        self._write_prefixed_bencode(headers)
 
1122
 
 
1123
    def _write_structure(self, args):
 
1124
        self._write_func(b's')
 
1125
        utf8_args = []
 
1126
        for arg in args:
 
1127
            if isinstance(arg, text_type):
 
1128
                utf8_args.append(arg.encode('utf8'))
 
1129
            else:
 
1130
                utf8_args.append(arg)
 
1131
        self._write_prefixed_bencode(utf8_args)
 
1132
 
 
1133
    def _write_end(self):
 
1134
        self._write_func(b'e')
 
1135
        self.flush()
 
1136
 
 
1137
    def _write_prefixed_body(self, bytes):
 
1138
        self._write_func(b'b')
 
1139
        self._write_func(struct.pack('!L', len(bytes)))
 
1140
        self._write_func(bytes)
 
1141
 
 
1142
    def _write_chunked_body_start(self):
 
1143
        self._write_func(b'oC')
 
1144
 
 
1145
    def _write_error_status(self):
 
1146
        self._write_func(b'oE')
 
1147
 
 
1148
    def _write_success_status(self):
 
1149
        self._write_func(b'oS')
 
1150
 
 
1151
 
 
1152
class ProtocolThreeResponder(_ProtocolThreeEncoder):
 
1153
 
 
1154
    def __init__(self, write_func):
 
1155
        _ProtocolThreeEncoder.__init__(self, write_func)
 
1156
        self.response_sent = False
 
1157
        self._headers = {
 
1158
                b'Software version': breezy.__version__.encode('utf-8')}
 
1159
        if 'hpss' in debug.debug_flags:
 
1160
            self._thread_id = _thread.get_ident()
 
1161
            self._response_start_time = None
 
1162
 
 
1163
    def _trace(self, action, message, extra_bytes=None, include_time=False):
 
1164
        if self._response_start_time is None:
 
1165
            self._response_start_time = osutils.timer_func()
 
1166
        if include_time:
 
1167
            t = '%5.3fs ' % (time.clock() - self._response_start_time)
 
1168
        else:
 
1169
            t = ''
 
1170
        if extra_bytes is None:
 
1171
            extra = ''
 
1172
        else:
 
1173
            extra = ' ' + repr(extra_bytes[:40])
 
1174
            if len(extra) > 33:
 
1175
                extra = extra[:29] + extra[-1] + '...'
 
1176
        mutter('%12s: [%s] %s%s%s'
 
1177
               % (action, self._thread_id, t, message, extra))
 
1178
 
 
1179
    def send_error(self, exception):
 
1180
        if self.response_sent:
 
1181
            raise AssertionError(
 
1182
                "send_error(%s) called, but response already sent."
 
1183
                % (exception,))
 
1184
        if isinstance(exception, errors.UnknownSmartMethod):
 
1185
            failure = request.FailedSmartServerResponse(
 
1186
                (b'UnknownMethod', exception.verb))
 
1187
            self.send_response(failure)
 
1188
            return
 
1189
        if 'hpss' in debug.debug_flags:
 
1190
            self._trace('error', str(exception))
 
1191
        self.response_sent = True
 
1192
        self._write_protocol_version()
 
1193
        self._write_headers(self._headers)
 
1194
        self._write_error_status()
 
1195
        self._write_structure((b'error', str(exception).encode('utf-8', 'replace')))
 
1196
        self._write_end()
 
1197
 
 
1198
    def send_response(self, response):
 
1199
        if self.response_sent:
 
1200
            raise AssertionError(
 
1201
                "send_response(%r) called, but response already sent."
 
1202
                % (response,))
 
1203
        self.response_sent = True
 
1204
        self._write_protocol_version()
 
1205
        self._write_headers(self._headers)
 
1206
        if response.is_successful():
 
1207
            self._write_success_status()
 
1208
        else:
 
1209
            self._write_error_status()
 
1210
        if 'hpss' in debug.debug_flags:
 
1211
            self._trace('response', repr(response.args))
 
1212
        self._write_structure(response.args)
 
1213
        if response.body is not None:
 
1214
            self._write_prefixed_body(response.body)
 
1215
            if 'hpss' in debug.debug_flags:
 
1216
                self._trace('body', '%d bytes' % (len(response.body),),
 
1217
                            response.body, include_time=True)
 
1218
        elif response.body_stream is not None:
 
1219
            count = num_bytes = 0
 
1220
            first_chunk = None
 
1221
            for exc_info, chunk in _iter_with_errors(response.body_stream):
 
1222
                count += 1
 
1223
                if exc_info is not None:
 
1224
                    self._write_error_status()
 
1225
                    error_struct = request._translate_error(exc_info[1])
 
1226
                    self._write_structure(error_struct)
 
1227
                    break
 
1228
                else:
 
1229
                    if isinstance(chunk, request.FailedSmartServerResponse):
 
1230
                        self._write_error_status()
 
1231
                        self._write_structure(chunk.args)
 
1232
                        break
 
1233
                    num_bytes += len(chunk)
 
1234
                    if first_chunk is None:
 
1235
                        first_chunk = chunk
 
1236
                    self._write_prefixed_body(chunk)
 
1237
                    self.flush()
 
1238
                    if 'hpssdetail' in debug.debug_flags:
 
1239
                        # Not worth timing separately, as _write_func is
 
1240
                        # actually buffered
 
1241
                        self._trace('body chunk',
 
1242
                                    '%d bytes' % (len(chunk),),
 
1243
                                    chunk, suppress_time=True)
 
1244
            if 'hpss' in debug.debug_flags:
 
1245
                self._trace('body stream',
 
1246
                            '%d bytes %d chunks' % (num_bytes, count),
 
1247
                            first_chunk)
 
1248
        self._write_end()
 
1249
        if 'hpss' in debug.debug_flags:
 
1250
            self._trace('response end', '', include_time=True)
 
1251
 
 
1252
 
 
1253
def _iter_with_errors(iterable):
 
1254
    """Handle errors from iterable.next().
 
1255
 
 
1256
    Use like::
 
1257
 
 
1258
        for exc_info, value in _iter_with_errors(iterable):
 
1259
            ...
 
1260
 
 
1261
    This is a safer alternative to::
 
1262
 
 
1263
        try:
 
1264
            for value in iterable:
 
1265
               ...
 
1266
        except:
 
1267
            ...
 
1268
 
 
1269
    Because the latter will catch errors from the for-loop body, not just
 
1270
    iterable.next()
 
1271
 
 
1272
    If an error occurs, exc_info will be a exc_info tuple, and the generator
 
1273
    will terminate.  Otherwise exc_info will be None, and value will be the
 
1274
    value from iterable.next().  Note that KeyboardInterrupt and SystemExit
 
1275
    will not be itercepted.
 
1276
    """
 
1277
    iterator = iter(iterable)
 
1278
    while True:
 
1279
        try:
 
1280
            yield None, next(iterator)
 
1281
        except StopIteration:
 
1282
            return
 
1283
        except (KeyboardInterrupt, SystemExit):
 
1284
            raise
 
1285
        except Exception:
 
1286
            mutter('_iter_with_errors caught error')
 
1287
            log_exception_quietly()
 
1288
            yield sys.exc_info(), None
 
1289
            return
 
1290
 
 
1291
 
 
1292
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
 
1293
 
 
1294
    def __init__(self, medium_request):
 
1295
        _ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
 
1296
        self._medium_request = medium_request
 
1297
        self._headers = {}
 
1298
        self.body_stream_started = None
 
1299
 
 
1300
    def set_headers(self, headers):
 
1301
        self._headers = headers.copy()
 
1302
 
 
1303
    def call(self, *args):
 
1304
        if 'hpss' in debug.debug_flags:
 
1305
            mutter('hpss call:   %s', repr(args)[1:-1])
 
1306
            base = getattr(self._medium_request._medium, 'base', None)
 
1307
            if base is not None:
 
1308
                mutter('             (to %s)', base)
 
1309
            self._request_start_time = osutils.timer_func()
 
1310
        self._write_protocol_version()
 
1311
        self._write_headers(self._headers)
 
1312
        self._write_structure(args)
 
1313
        self._write_end()
 
1314
        self._medium_request.finished_writing()
 
1315
 
 
1316
    def call_with_body_bytes(self, args, body):
 
1317
        """Make a remote call of args with body bytes 'body'.
 
1318
 
 
1319
        After calling this, call read_response_tuple to find the result out.
 
1320
        """
 
1321
        if 'hpss' in debug.debug_flags:
 
1322
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
 
1323
            path = getattr(self._medium_request._medium, '_path', None)
 
1324
            if path is not None:
 
1325
                mutter('                  (to %s)', path)
 
1326
            mutter('              %d bytes', len(body))
 
1327
            self._request_start_time = osutils.timer_func()
 
1328
        self._write_protocol_version()
 
1329
        self._write_headers(self._headers)
 
1330
        self._write_structure(args)
 
1331
        self._write_prefixed_body(body)
 
1332
        self._write_end()
 
1333
        self._medium_request.finished_writing()
 
1334
 
 
1335
    def call_with_body_readv_array(self, args, body):
 
1336
        """Make a remote call with a readv array.
 
1337
 
 
1338
        The body is encoded with one line per readv offset pair. The numbers in
 
1339
        each pair are separated by a comma, and no trailing \\n is emitted.
 
1340
        """
 
1341
        if 'hpss' in debug.debug_flags:
 
1342
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
 
1343
            path = getattr(self._medium_request._medium, '_path', None)
 
1344
            if path is not None:
 
1345
                mutter('                  (to %s)', path)
 
1346
            self._request_start_time = osutils.timer_func()
 
1347
        self._write_protocol_version()
 
1348
        self._write_headers(self._headers)
 
1349
        self._write_structure(args)
 
1350
        readv_bytes = self._serialise_offsets(body)
 
1351
        if 'hpss' in debug.debug_flags:
 
1352
            mutter('              %d bytes in readv request', len(readv_bytes))
 
1353
        self._write_prefixed_body(readv_bytes)
 
1354
        self._write_end()
 
1355
        self._medium_request.finished_writing()
 
1356
 
 
1357
    def call_with_body_stream(self, args, stream):
 
1358
        if 'hpss' in debug.debug_flags:
 
1359
            mutter('hpss call w/body stream: %r', args)
 
1360
            path = getattr(self._medium_request._medium, '_path', None)
 
1361
            if path is not None:
 
1362
                mutter('                  (to %s)', path)
 
1363
            self._request_start_time = osutils.timer_func()
 
1364
        self.body_stream_started = False
 
1365
        self._write_protocol_version()
 
1366
        self._write_headers(self._headers)
 
1367
        self._write_structure(args)
 
1368
        # TODO: notice if the server has sent an early error reply before we
 
1369
        #       have finished sending the stream.  We would notice at the end
 
1370
        #       anyway, but if the medium can deliver it early then it's good
 
1371
        #       to short-circuit the whole request...
 
1372
        # Provoke any ConnectionReset failures before we start the body stream.
 
1373
        self.flush()
 
1374
        self.body_stream_started = True
 
1375
        for exc_info, part in _iter_with_errors(stream):
 
1376
            if exc_info is not None:
 
1377
                # Iterating the stream failed.  Cleanly abort the request.
 
1378
                self._write_error_status()
 
1379
                # Currently the client unconditionally sends ('error',) as the
 
1380
                # error args.
 
1381
                self._write_structure((b'error',))
 
1382
                self._write_end()
 
1383
                self._medium_request.finished_writing()
 
1384
                try:
 
1385
                    reraise(*exc_info)
 
1386
                finally:
 
1387
                    del exc_info
 
1388
            else:
 
1389
                self._write_prefixed_body(part)
 
1390
                self.flush()
 
1391
        self._write_end()
 
1392
        self._medium_request.finished_writing()
 
1393