/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: John Arbash Meinel
  • Date: 2008-07-08 14:55:19 UTC
  • mfrom: (3530 +trunk)
  • mto: This revision was merged to the branch mainline in revision 3532.
  • Revision ID: john@arbash-meinel.com-20080708145519-paqg4kjwbpgs2xmq
Merge bzr.dev 3530

Show diffs side-by-side

added added

removed removed

Lines of Context:
27
27
import os
28
28
import socket
29
29
import sys
 
30
import urllib
30
31
 
31
32
from bzrlib import (
32
33
    errors,
33
34
    osutils,
34
35
    symbol_versioning,
 
36
    urlutils,
35
37
    )
36
38
from bzrlib.smart.protocol import (
 
39
    MESSAGE_VERSION_THREE,
37
40
    REQUEST_VERSION_TWO,
 
41
    SmartClientRequestProtocolOne,
38
42
    SmartServerRequestProtocolOne,
39
43
    SmartServerRequestProtocolTwo,
 
44
    build_server_protocol_three
40
45
    )
41
46
from bzrlib.transport import ssh
42
47
 
43
48
 
 
49
def _get_protocol_factory_for_bytes(bytes):
 
50
    """Determine the right protocol factory for 'bytes'.
 
51
 
 
52
    This will return an appropriate protocol factory depending on the version
 
53
    of the protocol being used, as determined by inspecting the given bytes.
 
54
    The bytes should have at least one newline byte (i.e. be a whole line),
 
55
    otherwise it's possible that a request will be incorrectly identified as
 
56
    version 1.
 
57
 
 
58
    Typical use would be::
 
59
 
 
60
         factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
 
61
         server_protocol = factory(transport, write_func, root_client_path)
 
62
         server_protocol.accept_bytes(unused_bytes)
 
63
 
 
64
    :param bytes: a str of bytes of the start of the request.
 
65
    :returns: 2-tuple of (protocol_factory, unused_bytes).  protocol_factory is
 
66
        a callable that takes three args: transport, write_func,
 
67
        root_client_path.  unused_bytes are any bytes that were not part of a
 
68
        protocol version marker.
 
69
    """
 
70
    if bytes.startswith(MESSAGE_VERSION_THREE):
 
71
        protocol_factory = build_server_protocol_three
 
72
        bytes = bytes[len(MESSAGE_VERSION_THREE):]
 
73
    elif bytes.startswith(REQUEST_VERSION_TWO):
 
74
        protocol_factory = SmartServerRequestProtocolTwo
 
75
        bytes = bytes[len(REQUEST_VERSION_TWO):]
 
76
    else:
 
77
        protocol_factory = SmartServerRequestProtocolOne
 
78
    return protocol_factory, bytes
 
79
 
 
80
 
44
81
class SmartServerStreamMedium(object):
45
82
    """Handles smart commands coming over a stream.
46
83
 
52
89
 
53
90
    The server passes requests through to an underlying backing transport, 
54
91
    which will typically be a LocalTransport looking at the server's filesystem.
 
92
 
 
93
    :ivar _push_back_buffer: a str of bytes that have been read from the stream
 
94
        but not used yet, or None if there are no buffered bytes.  Subclasses
 
95
        should make sure to exhaust this buffer before reading more bytes from
 
96
        the stream.  See also the _push_back method.
55
97
    """
56
98
 
57
 
    def __init__(self, backing_transport):
 
99
    def __init__(self, backing_transport, root_client_path='/'):
58
100
        """Construct new server.
59
101
 
60
102
        :param backing_transport: Transport for the directory served.
61
103
        """
62
104
        # backing_transport could be passed to serve instead of __init__
63
105
        self.backing_transport = backing_transport
 
106
        self.root_client_path = root_client_path
64
107
        self.finished = False
 
108
        self._push_back_buffer = None
 
109
 
 
110
    def _push_back(self, bytes):
 
111
        """Return unused bytes to the medium, because they belong to the next
 
112
        request(s).
 
113
 
 
114
        This sets the _push_back_buffer to the given bytes.
 
115
        """
 
116
        if self._push_back_buffer is not None:
 
117
            raise AssertionError(
 
118
                "_push_back called when self._push_back_buffer is %r"
 
119
                % (self._push_back_buffer,))
 
120
        if bytes == '':
 
121
            return
 
122
        self._push_back_buffer = bytes
 
123
 
 
124
    def _get_push_back_buffer(self):
 
125
        if self._push_back_buffer == '':
 
126
            raise AssertionError(
 
127
                '%s._push_back_buffer should never be the empty string, '
 
128
                'which can be confused with EOF' % (self,))
 
129
        bytes = self._push_back_buffer
 
130
        self._push_back_buffer = None
 
131
        return bytes
65
132
 
66
133
    def serve(self):
67
134
        """Serve requests until the client disconnects."""
85
152
 
86
153
        :returns: a SmartServerRequestProtocol.
87
154
        """
88
 
        # Identify the protocol version.
89
155
        bytes = self._get_line()
90
 
        if bytes.startswith(REQUEST_VERSION_TWO):
91
 
            protocol_class = SmartServerRequestProtocolTwo
92
 
            bytes = bytes[len(REQUEST_VERSION_TWO):]
93
 
        else:
94
 
            protocol_class = SmartServerRequestProtocolOne
95
 
        protocol = protocol_class(self.backing_transport, self._write_out)
96
 
        protocol.accept_bytes(bytes)
 
156
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
 
157
        protocol = protocol_factory(
 
158
            self.backing_transport, self._write_out, self.root_client_path)
 
159
        protocol.accept_bytes(unused_bytes)
97
160
        return protocol
98
161
 
99
162
    def _serve_one_request(self, protocol):
127
190
 
128
191
        :returns: a string of bytes ending in a newline (byte 0x0A).
129
192
        """
130
 
        # XXX: this duplicates SmartClientRequestProtocolOne._recv_tuple
131
 
        line = ''
132
 
        while not line or line[-1] != '\n':
133
 
            new_char = self._get_bytes(1)
134
 
            line += new_char
135
 
            if new_char == '':
 
193
        newline_pos = -1
 
194
        bytes = ''
 
195
        while newline_pos == -1:
 
196
            new_bytes = self._get_bytes(1)
 
197
            bytes += new_bytes
 
198
            if new_bytes == '':
136
199
                # Ran out of bytes before receiving a complete line.
137
 
                break
 
200
                return bytes
 
201
            newline_pos = bytes.find('\n')
 
202
        line = bytes[:newline_pos+1]
 
203
        self._push_back(bytes[newline_pos+1:])
138
204
        return line
139
 
 
 
205
 
140
206
 
141
207
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
142
208
 
143
 
    def __init__(self, sock, backing_transport):
 
209
    def __init__(self, sock, backing_transport, root_client_path='/'):
144
210
        """Constructor.
145
211
 
146
212
        :param sock: the socket the server will read from.  It will be put
147
213
            into blocking mode.
148
214
        """
149
 
        SmartServerStreamMedium.__init__(self, backing_transport)
150
 
        self.push_back = ''
 
215
        SmartServerStreamMedium.__init__(
 
216
            self, backing_transport, root_client_path=root_client_path)
151
217
        sock.setblocking(True)
152
218
        self.socket = sock
153
219
 
154
220
    def _serve_one_request_unguarded(self, protocol):
155
221
        while protocol.next_read_size():
156
 
            if self.push_back:
157
 
                protocol.accept_bytes(self.push_back)
158
 
                self.push_back = ''
159
 
            else:
160
 
                bytes = self._get_bytes(4096)
161
 
                if bytes == '':
162
 
                    self.finished = True
163
 
                    return
164
 
                protocol.accept_bytes(bytes)
 
222
            bytes = self._get_bytes(4096)
 
223
            if bytes == '':
 
224
                self.finished = True
 
225
                return
 
226
            protocol.accept_bytes(bytes)
165
227
        
166
 
        self.push_back = protocol.excess_buffer
 
228
        self._push_back(protocol.unused_data)
167
229
 
168
230
    def _get_bytes(self, desired_count):
 
231
        if self._push_back_buffer is not None:
 
232
            return self._get_push_back_buffer()
169
233
        # We ignore the desired_count because on sockets it's more efficient to
170
234
        # read 4k at a time.
171
235
        return self.socket.recv(4096)
172
236
    
173
237
    def terminate_due_to_error(self):
174
 
        """Called when an unhandled exception from the protocol occurs."""
175
238
        # TODO: This should log to a server log file, but no such thing
176
239
        # exists yet.  Andrew Bennetts 2006-09-29.
177
240
        self.socket.close()
217
280
            protocol.accept_bytes(bytes)
218
281
 
219
282
    def _get_bytes(self, desired_count):
 
283
        if self._push_back_buffer is not None:
 
284
            return self._get_push_back_buffer()
220
285
        return self._in.read(desired_count)
221
286
 
222
287
    def terminate_due_to_error(self):
368
433
class SmartClientMedium(object):
369
434
    """Smart client is a medium for sending smart protocol requests over."""
370
435
 
 
436
    def __init__(self, base):
 
437
        super(SmartClientMedium, self).__init__()
 
438
        self.base = base
 
439
        self._protocol_version_error = None
 
440
        self._protocol_version = None
 
441
        self._done_hello = False
 
442
        # Be optimistic: we assume the remote end can accept new remote
 
443
        # requests until we get an error saying otherwise.
 
444
        # _remote_version_is_before tracks the bzr version the remote side
 
445
        # can be based on what we've seen so far.
 
446
        self._remote_version_is_before = None
 
447
 
 
448
    def _is_remote_before(self, version_tuple):
 
449
        """Is it possible the remote side supports RPCs for a given version?
 
450
 
 
451
        Typical use::
 
452
 
 
453
            needed_version = (1, 2)
 
454
            if medium._is_remote_before(needed_version):
 
455
                fallback_to_pre_1_2_rpc()
 
456
            else:
 
457
                try:
 
458
                    do_1_2_rpc()
 
459
                except UnknownSmartMethod:
 
460
                    medium._remember_remote_is_before(needed_version)
 
461
                    fallback_to_pre_1_2_rpc()
 
462
 
 
463
        :seealso: _remember_remote_is_before
 
464
        """
 
465
        if self._remote_version_is_before is None:
 
466
            # So far, the remote side seems to support everything
 
467
            return False
 
468
        return version_tuple >= self._remote_version_is_before
 
469
 
 
470
    def _remember_remote_is_before(self, version_tuple):
 
471
        """Tell this medium that the remote side is older the given version.
 
472
 
 
473
        :seealso: _is_remote_before
 
474
        """
 
475
        if (self._remote_version_is_before is not None and
 
476
            version_tuple > self._remote_version_is_before):
 
477
            raise AssertionError(
 
478
                "_remember_remote_is_before(%r) called, but "
 
479
                "_remember_remote_is_before(%r) was called previously."
 
480
                % (version_tuple, self._remote_version_is_before))
 
481
        self._remote_version_is_before = version_tuple
 
482
 
 
483
    def protocol_version(self):
 
484
        """Find out if 'hello' smart request works."""
 
485
        if self._protocol_version_error is not None:
 
486
            raise self._protocol_version_error
 
487
        if not self._done_hello:
 
488
            try:
 
489
                medium_request = self.get_request()
 
490
                # Send a 'hello' request in protocol version one, for maximum
 
491
                # backwards compatibility.
 
492
                client_protocol = SmartClientRequestProtocolOne(medium_request)
 
493
                client_protocol.query_version()
 
494
                self._done_hello = True
 
495
            except errors.SmartProtocolError, e:
 
496
                # Cache the error, just like we would cache a successful
 
497
                # result.
 
498
                self._protocol_version_error = e
 
499
                raise
 
500
        return '2'
 
501
 
 
502
    def should_probe(self):
 
503
        """Should RemoteBzrDirFormat.probe_transport send a smart request on
 
504
        this medium?
 
505
 
 
506
        Some transports are unambiguously smart-only; there's no need to check
 
507
        if the transport is able to carry smart requests, because that's all
 
508
        it is for.  In those cases, this method should return False.
 
509
 
 
510
        But some HTTP transports can sometimes fail to carry smart requests,
 
511
        but still be usuable for accessing remote bzrdirs via plain file
 
512
        accesses.  So for those transports, their media should return True here
 
513
        so that RemoteBzrDirFormat can determine if it is appropriate for that
 
514
        transport.
 
515
        """
 
516
        return False
 
517
 
371
518
    def disconnect(self):
372
519
        """If this medium maintains a persistent connection, close it.
373
520
        
374
521
        The default implementation does nothing.
375
522
        """
376
523
        
 
524
    def remote_path_from_transport(self, transport):
 
525
        """Convert transport into a path suitable for using in a request.
 
526
        
 
527
        Note that the resulting remote path doesn't encode the host name or
 
528
        anything but path, so it is only safe to use it in requests sent over
 
529
        the medium from the matching transport.
 
530
        """
 
531
        medium_base = urlutils.join(self.base, '/')
 
532
        rel_url = urlutils.relative_url(medium_base, transport.base)
 
533
        return urllib.unquote(rel_url)
 
534
 
377
535
 
378
536
class SmartClientStreamMedium(SmartClientMedium):
379
537
    """Stream based medium common class.
384
542
    receive bytes.
385
543
    """
386
544
 
387
 
    def __init__(self):
 
545
    def __init__(self, base):
 
546
        SmartClientMedium.__init__(self, base)
388
547
        self._current_request = None
389
 
        # Be optimistic: we assume the remote end can accept new remote
390
 
        # requests until we get an error saying otherwise.  (1.2 adds some
391
 
        # requests that send bodies, which confuses older servers.)
392
 
        self._remote_is_at_least_1_2 = True
393
548
 
394
549
    def accept_bytes(self, bytes):
395
550
        self._accept_bytes(bytes)
426
581
    This client does not manage the pipes: it assumes they will always be open.
427
582
    """
428
583
 
429
 
    def __init__(self, readable_pipe, writeable_pipe):
430
 
        SmartClientStreamMedium.__init__(self)
 
584
    def __init__(self, readable_pipe, writeable_pipe, base):
 
585
        SmartClientStreamMedium.__init__(self, base)
431
586
        self._readable_pipe = readable_pipe
432
587
        self._writeable_pipe = writeable_pipe
433
588
 
448
603
    """A client medium using SSH."""
449
604
    
450
605
    def __init__(self, host, port=None, username=None, password=None,
451
 
            vendor=None, bzr_remote_path=None):
 
606
            base=None, vendor=None, bzr_remote_path=None):
452
607
        """Creates a client that will connect on the first use.
453
608
        
454
609
        :param vendor: An optional override for the ssh vendor to use. See
455
610
            bzrlib.transport.ssh for details on ssh vendors.
456
611
        """
457
 
        SmartClientStreamMedium.__init__(self)
 
612
        SmartClientStreamMedium.__init__(self, base)
458
613
        self._connected = False
459
614
        self._host = host
460
615
        self._password = password
520
675
class SmartTCPClientMedium(SmartClientStreamMedium):
521
676
    """A client medium using TCP."""
522
677
    
523
 
    def __init__(self, host, port):
 
678
    def __init__(self, host, port, base):
524
679
        """Creates a client that will connect on the first use."""
525
 
        SmartClientStreamMedium.__init__(self)
 
680
        SmartClientStreamMedium.__init__(self, base)
526
681
        self._connected = False
527
682
        self._host = host
528
683
        self._port = port
606
761
        This clears the _current_request on self._medium to allow a new 
607
762
        request to be created.
608
763
        """
609
 
        assert self._medium._current_request is self
 
764
        if self._medium._current_request is not self:
 
765
            raise AssertionError()
610
766
        self._medium._current_request = None
611
767
        
612
768
    def _finished_writing(self):