/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: Vincent Ladeuil
  • Date: 2008-10-02 13:24:32 UTC
  • mto: This revision was merged to the branch mainline in revision 3760.
  • Revision ID: v.ladeuil+lp@free.fr-20081002132432-iwlhbyhmjgxbik99
Cleanups.

* bzrlib/tests/test_bundle.py: 
Fix module import order.
(TestReadMergeableFromUrl.test_smart_server_connection_reset): Add
comment.

* bzrlib/tests/__init__.py: 
(test_suite): Fix test module names order.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
"""The 'medium' layer for the smart servers and clients.
 
18
 
 
19
"Medium" here is the noun meaning "a means of transmission", not the adjective
 
20
for "the quality between big and small."
 
21
 
 
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
 
23
over SSH), and pass them to and from the protocol logic.  See the overview in
 
24
bzrlib/transport/smart/__init__.py.
 
25
"""
 
26
 
 
27
import errno
 
28
import os
 
29
import socket
 
30
import sys
 
31
import urllib
 
32
 
 
33
from bzrlib.lazy_import import lazy_import
 
34
lazy_import(globals(), """
 
35
from bzrlib import (
 
36
    errors,
 
37
    osutils,
 
38
    symbol_versioning,
 
39
    urlutils,
 
40
    )
 
41
from bzrlib.smart import protocol
 
42
from bzrlib.transport import ssh
 
43
""")
 
44
 
 
45
 
 
46
# We must not read any more than 64k at a time so we don't risk "no buffer
 
47
# space available" errors on some platforms.  Windows in particular is likely
 
48
# to give error 10053 or 10055 if we read more than 64k from a socket.
 
49
_MAX_READ_SIZE = 64 * 1024
 
50
 
 
51
 
 
52
def _get_protocol_factory_for_bytes(bytes):
 
53
    """Determine the right protocol factory for 'bytes'.
 
54
 
 
55
    This will return an appropriate protocol factory depending on the version
 
56
    of the protocol being used, as determined by inspecting the given bytes.
 
57
    The bytes should have at least one newline byte (i.e. be a whole line),
 
58
    otherwise it's possible that a request will be incorrectly identified as
 
59
    version 1.
 
60
 
 
61
    Typical use would be::
 
62
 
 
63
         factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
 
64
         server_protocol = factory(transport, write_func, root_client_path)
 
65
         server_protocol.accept_bytes(unused_bytes)
 
66
 
 
67
    :param bytes: a str of bytes of the start of the request.
 
68
    :returns: 2-tuple of (protocol_factory, unused_bytes).  protocol_factory is
 
69
        a callable that takes three args: transport, write_func,
 
70
        root_client_path.  unused_bytes are any bytes that were not part of a
 
71
        protocol version marker.
 
72
    """
 
73
    if bytes.startswith(protocol.MESSAGE_VERSION_THREE):
 
74
        protocol_factory = protocol.build_server_protocol_three
 
75
        bytes = bytes[len(protocol.MESSAGE_VERSION_THREE):]
 
76
    elif bytes.startswith(protocol.REQUEST_VERSION_TWO):
 
77
        protocol_factory = protocol.SmartServerRequestProtocolTwo
 
78
        bytes = bytes[len(protocol.REQUEST_VERSION_TWO):]
 
79
    else:
 
80
        protocol_factory = protocol.SmartServerRequestProtocolOne
 
81
    return protocol_factory, bytes
 
82
 
 
83
 
 
84
def _get_line(read_bytes_func):
 
85
    """Read bytes using read_bytes_func until a newline byte.
 
86
    
 
87
    This isn't particularly efficient, so should only be used when the
 
88
    expected size of the line is quite short.
 
89
    
 
90
    :returns: a tuple of two strs: (line, excess)
 
91
    """
 
92
    newline_pos = -1
 
93
    bytes = ''
 
94
    while newline_pos == -1:
 
95
        new_bytes = read_bytes_func(1)
 
96
        bytes += new_bytes
 
97
        if new_bytes == '':
 
98
            # Ran out of bytes before receiving a complete line.
 
99
            return bytes, ''
 
100
        newline_pos = bytes.find('\n')
 
101
    line = bytes[:newline_pos+1]
 
102
    excess = bytes[newline_pos+1:]
 
103
    return line, excess
 
104
 
 
105
 
 
106
class SmartMedium(object):
 
107
    """Base class for smart protocol media, both client- and server-side."""
 
108
 
 
109
    def __init__(self):
 
110
        self._push_back_buffer = None
 
111
        
 
112
    def _push_back(self, bytes):
 
113
        """Return unused bytes to the medium, because they belong to the next
 
114
        request(s).
 
115
 
 
116
        This sets the _push_back_buffer to the given bytes.
 
117
        """
 
118
        if self._push_back_buffer is not None:
 
119
            raise AssertionError(
 
120
                "_push_back called when self._push_back_buffer is %r"
 
121
                % (self._push_back_buffer,))
 
122
        if bytes == '':
 
123
            return
 
124
        self._push_back_buffer = bytes
 
125
 
 
126
    def _get_push_back_buffer(self):
 
127
        if self._push_back_buffer == '':
 
128
            raise AssertionError(
 
129
                '%s._push_back_buffer should never be the empty string, '
 
130
                'which can be confused with EOF' % (self,))
 
131
        bytes = self._push_back_buffer
 
132
        self._push_back_buffer = None
 
133
        return bytes
 
134
 
 
135
    def read_bytes(self, desired_count):
 
136
        """Read some bytes from this medium.
 
137
 
 
138
        :returns: some bytes, possibly more or less than the number requested
 
139
            in 'desired_count' depending on the medium.
 
140
        """
 
141
        if self._push_back_buffer is not None:
 
142
            return self._get_push_back_buffer()
 
143
        bytes_to_read = min(desired_count, _MAX_READ_SIZE)
 
144
        return self._read_bytes(bytes_to_read)
 
145
 
 
146
    def _read_bytes(self, count):
 
147
        raise NotImplementedError(self._read_bytes)
 
148
 
 
149
    def _get_line(self):
 
150
        """Read bytes from this request's response until a newline byte.
 
151
        
 
152
        This isn't particularly efficient, so should only be used when the
 
153
        expected size of the line is quite short.
 
154
 
 
155
        :returns: a string of bytes ending in a newline (byte 0x0A).
 
156
        """
 
157
        line, excess = _get_line(self.read_bytes)
 
158
        self._push_back(excess)
 
159
        return line
 
160
 
 
161
 
 
162
class SmartServerStreamMedium(SmartMedium):
 
163
    """Handles smart commands coming over a stream.
 
164
 
 
165
    The stream may be a pipe connected to sshd, or a tcp socket, or an
 
166
    in-process fifo for testing.
 
167
 
 
168
    One instance is created for each connected client; it can serve multiple
 
169
    requests in the lifetime of the connection.
 
170
 
 
171
    The server passes requests through to an underlying backing transport, 
 
172
    which will typically be a LocalTransport looking at the server's filesystem.
 
173
 
 
174
    :ivar _push_back_buffer: a str of bytes that have been read from the stream
 
175
        but not used yet, or None if there are no buffered bytes.  Subclasses
 
176
        should make sure to exhaust this buffer before reading more bytes from
 
177
        the stream.  See also the _push_back method.
 
178
    """
 
179
 
 
180
    def __init__(self, backing_transport, root_client_path='/'):
 
181
        """Construct new server.
 
182
 
 
183
        :param backing_transport: Transport for the directory served.
 
184
        """
 
185
        # backing_transport could be passed to serve instead of __init__
 
186
        self.backing_transport = backing_transport
 
187
        self.root_client_path = root_client_path
 
188
        self.finished = False
 
189
        SmartMedium.__init__(self)
 
190
 
 
191
    def serve(self):
 
192
        """Serve requests until the client disconnects."""
 
193
        # Keep a reference to stderr because the sys module's globals get set to
 
194
        # None during interpreter shutdown.
 
195
        from sys import stderr
 
196
        try:
 
197
            while not self.finished:
 
198
                server_protocol = self._build_protocol()
 
199
                self._serve_one_request(server_protocol)
 
200
        except Exception, e:
 
201
            stderr.write("%s terminating on exception %s\n" % (self, e))
 
202
            raise
 
203
 
 
204
    def _build_protocol(self):
 
205
        """Identifies the version of the incoming request, and returns an
 
206
        a protocol object that can interpret it.
 
207
 
 
208
        If more bytes than the version prefix of the request are read, they will
 
209
        be fed into the protocol before it is returned.
 
210
 
 
211
        :returns: a SmartServerRequestProtocol.
 
212
        """
 
213
        bytes = self._get_line()
 
214
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
 
215
        protocol = protocol_factory(
 
216
            self.backing_transport, self._write_out, self.root_client_path)
 
217
        protocol.accept_bytes(unused_bytes)
 
218
        return protocol
 
219
 
 
220
    def _serve_one_request(self, protocol):
 
221
        """Read one request from input, process, send back a response.
 
222
        
 
223
        :param protocol: a SmartServerRequestProtocol.
 
224
        """
 
225
        try:
 
226
            self._serve_one_request_unguarded(protocol)
 
227
        except KeyboardInterrupt:
 
228
            raise
 
229
        except Exception, e:
 
230
            self.terminate_due_to_error()
 
231
 
 
232
    def terminate_due_to_error(self):
 
233
        """Called when an unhandled exception from the protocol occurs."""
 
234
        raise NotImplementedError(self.terminate_due_to_error)
 
235
 
 
236
    def _read_bytes(self, desired_count):
 
237
        """Get some bytes from the medium.
 
238
 
 
239
        :param desired_count: number of bytes we want to read.
 
240
        """
 
241
        raise NotImplementedError(self._read_bytes)
 
242
 
 
243
 
 
244
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
 
245
 
 
246
    def __init__(self, sock, backing_transport, root_client_path='/'):
 
247
        """Constructor.
 
248
 
 
249
        :param sock: the socket the server will read from.  It will be put
 
250
            into blocking mode.
 
251
        """
 
252
        SmartServerStreamMedium.__init__(
 
253
            self, backing_transport, root_client_path=root_client_path)
 
254
        sock.setblocking(True)
 
255
        self.socket = sock
 
256
 
 
257
    def _serve_one_request_unguarded(self, protocol):
 
258
        while protocol.next_read_size():
 
259
            # We can safely try to read large chunks.  If there is less data
 
260
            # than _MAX_READ_SIZE ready, the socket wil just return a short
 
261
            # read immediately rather than block.
 
262
            bytes = self.read_bytes(_MAX_READ_SIZE)
 
263
            if bytes == '':
 
264
                self.finished = True
 
265
                return
 
266
            protocol.accept_bytes(bytes)
 
267
        
 
268
        self._push_back(protocol.unused_data)
 
269
 
 
270
    def _read_bytes(self, desired_count):
 
271
        # We ignore the desired_count because on sockets it's more efficient to
 
272
        # read large chunks (of _MAX_READ_SIZE bytes) at a time.
 
273
        return self.socket.recv(_MAX_READ_SIZE)
 
274
 
 
275
    def terminate_due_to_error(self):
 
276
        # TODO: This should log to a server log file, but no such thing
 
277
        # exists yet.  Andrew Bennetts 2006-09-29.
 
278
        self.socket.close()
 
279
        self.finished = True
 
280
 
 
281
    def _write_out(self, bytes):
 
282
        osutils.send_all(self.socket, bytes)
 
283
 
 
284
 
 
285
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
 
286
 
 
287
    def __init__(self, in_file, out_file, backing_transport):
 
288
        """Construct new server.
 
289
 
 
290
        :param in_file: Python file from which requests can be read.
 
291
        :param out_file: Python file to write responses.
 
292
        :param backing_transport: Transport for the directory served.
 
293
        """
 
294
        SmartServerStreamMedium.__init__(self, backing_transport)
 
295
        if sys.platform == 'win32':
 
296
            # force binary mode for files
 
297
            import msvcrt
 
298
            for f in (in_file, out_file):
 
299
                fileno = getattr(f, 'fileno', None)
 
300
                if fileno:
 
301
                    msvcrt.setmode(fileno(), os.O_BINARY)
 
302
        self._in = in_file
 
303
        self._out = out_file
 
304
 
 
305
    def _serve_one_request_unguarded(self, protocol):
 
306
        while True:
 
307
            # We need to be careful not to read past the end of the current
 
308
            # request, or else the read from the pipe will block, so we use
 
309
            # protocol.next_read_size().
 
310
            bytes_to_read = protocol.next_read_size()
 
311
            if bytes_to_read == 0:
 
312
                # Finished serving this request.
 
313
                self._out.flush()
 
314
                return
 
315
            bytes = self.read_bytes(bytes_to_read)
 
316
            if bytes == '':
 
317
                # Connection has been closed.
 
318
                self.finished = True
 
319
                self._out.flush()
 
320
                return
 
321
            protocol.accept_bytes(bytes)
 
322
 
 
323
    def _read_bytes(self, desired_count):
 
324
        return self._in.read(desired_count)
 
325
 
 
326
    def terminate_due_to_error(self):
 
327
        # TODO: This should log to a server log file, but no such thing
 
328
        # exists yet.  Andrew Bennetts 2006-09-29.
 
329
        self._out.close()
 
330
        self.finished = True
 
331
 
 
332
    def _write_out(self, bytes):
 
333
        self._out.write(bytes)
 
334
 
 
335
 
 
336
class SmartClientMediumRequest(object):
 
337
    """A request on a SmartClientMedium.
 
338
 
 
339
    Each request allows bytes to be provided to it via accept_bytes, and then
 
340
    the response bytes to be read via read_bytes.
 
341
 
 
342
    For instance:
 
343
    request.accept_bytes('123')
 
344
    request.finished_writing()
 
345
    result = request.read_bytes(3)
 
346
    request.finished_reading()
 
347
 
 
348
    It is up to the individual SmartClientMedium whether multiple concurrent
 
349
    requests can exist. See SmartClientMedium.get_request to obtain instances 
 
350
    of SmartClientMediumRequest, and the concrete Medium you are using for 
 
351
    details on concurrency and pipelining.
 
352
    """
 
353
 
 
354
    def __init__(self, medium):
 
355
        """Construct a SmartClientMediumRequest for the medium medium."""
 
356
        self._medium = medium
 
357
        # we track state by constants - we may want to use the same
 
358
        # pattern as BodyReader if it gets more complex.
 
359
        # valid states are: "writing", "reading", "done"
 
360
        self._state = "writing"
 
361
 
 
362
    def accept_bytes(self, bytes):
 
363
        """Accept bytes for inclusion in this request.
 
364
 
 
365
        This method may not be be called after finished_writing() has been
 
366
        called.  It depends upon the Medium whether or not the bytes will be
 
367
        immediately transmitted. Message based Mediums will tend to buffer the
 
368
        bytes until finished_writing() is called.
 
369
 
 
370
        :param bytes: A bytestring.
 
371
        """
 
372
        if self._state != "writing":
 
373
            raise errors.WritingCompleted(self)
 
374
        self._accept_bytes(bytes)
 
375
 
 
376
    def _accept_bytes(self, bytes):
 
377
        """Helper for accept_bytes.
 
378
 
 
379
        Accept_bytes checks the state of the request to determing if bytes
 
380
        should be accepted. After that it hands off to _accept_bytes to do the
 
381
        actual acceptance.
 
382
        """
 
383
        raise NotImplementedError(self._accept_bytes)
 
384
 
 
385
    def finished_reading(self):
 
386
        """Inform the request that all desired data has been read.
 
387
 
 
388
        This will remove the request from the pipeline for its medium (if the
 
389
        medium supports pipelining) and any further calls to methods on the
 
390
        request will raise ReadingCompleted.
 
391
        """
 
392
        if self._state == "writing":
 
393
            raise errors.WritingNotComplete(self)
 
394
        if self._state != "reading":
 
395
            raise errors.ReadingCompleted(self)
 
396
        self._state = "done"
 
397
        self._finished_reading()
 
398
 
 
399
    def _finished_reading(self):
 
400
        """Helper for finished_reading.
 
401
 
 
402
        finished_reading checks the state of the request to determine if 
 
403
        finished_reading is allowed, and if it is hands off to _finished_reading
 
404
        to perform the action.
 
405
        """
 
406
        raise NotImplementedError(self._finished_reading)
 
407
 
 
408
    def finished_writing(self):
 
409
        """Finish the writing phase of this request.
 
410
 
 
411
        This will flush all pending data for this request along the medium.
 
412
        After calling finished_writing, you may not call accept_bytes anymore.
 
413
        """
 
414
        if self._state != "writing":
 
415
            raise errors.WritingCompleted(self)
 
416
        self._state = "reading"
 
417
        self._finished_writing()
 
418
 
 
419
    def _finished_writing(self):
 
420
        """Helper for finished_writing.
 
421
 
 
422
        finished_writing checks the state of the request to determine if 
 
423
        finished_writing is allowed, and if it is hands off to _finished_writing
 
424
        to perform the action.
 
425
        """
 
426
        raise NotImplementedError(self._finished_writing)
 
427
 
 
428
    def read_bytes(self, count):
 
429
        """Read bytes from this requests response.
 
430
 
 
431
        This method will block and wait for count bytes to be read. It may not
 
432
        be invoked until finished_writing() has been called - this is to ensure
 
433
        a message-based approach to requests, for compatibility with message
 
434
        based mediums like HTTP.
 
435
        """
 
436
        if self._state == "writing":
 
437
            raise errors.WritingNotComplete(self)
 
438
        if self._state != "reading":
 
439
            raise errors.ReadingCompleted(self)
 
440
        return self._read_bytes(count)
 
441
 
 
442
    def _read_bytes(self, count):
 
443
        """Helper for SmartClientMediumRequest.read_bytes.
 
444
 
 
445
        read_bytes checks the state of the request to determing if bytes
 
446
        should be read. After that it hands off to _read_bytes to do the
 
447
        actual read.
 
448
        
 
449
        By default this forwards to self._medium.read_bytes because we are
 
450
        operating on the medium's stream.
 
451
        """
 
452
        return self._medium.read_bytes(count)
 
453
 
 
454
    def read_line(self):
 
455
        line = self._read_line()
 
456
        if not line.endswith('\n'):
 
457
            # end of file encountered reading from server
 
458
            raise errors.ConnectionReset(
 
459
                "please check connectivity and permissions",
 
460
                "(and try -Dhpss if further diagnosis is required)")
 
461
        return line
 
462
 
 
463
    def _read_line(self):
 
464
        """Helper for SmartClientMediumRequest.read_line.
 
465
        
 
466
        By default this forwards to self._medium._get_line because we are
 
467
        operating on the medium's stream.
 
468
        """
 
469
        return self._medium._get_line()
 
470
 
 
471
 
 
472
class SmartClientMedium(SmartMedium):
 
473
    """Smart client is a medium for sending smart protocol requests over."""
 
474
 
 
475
    def __init__(self, base):
 
476
        super(SmartClientMedium, self).__init__()
 
477
        self.base = base
 
478
        self._protocol_version_error = None
 
479
        self._protocol_version = None
 
480
        self._done_hello = False
 
481
        # Be optimistic: we assume the remote end can accept new remote
 
482
        # requests until we get an error saying otherwise.
 
483
        # _remote_version_is_before tracks the bzr version the remote side
 
484
        # can be based on what we've seen so far.
 
485
        self._remote_version_is_before = None
 
486
 
 
487
    def _is_remote_before(self, version_tuple):
 
488
        """Is it possible the remote side supports RPCs for a given version?
 
489
 
 
490
        Typical use::
 
491
 
 
492
            needed_version = (1, 2)
 
493
            if medium._is_remote_before(needed_version):
 
494
                fallback_to_pre_1_2_rpc()
 
495
            else:
 
496
                try:
 
497
                    do_1_2_rpc()
 
498
                except UnknownSmartMethod:
 
499
                    medium._remember_remote_is_before(needed_version)
 
500
                    fallback_to_pre_1_2_rpc()
 
501
 
 
502
        :seealso: _remember_remote_is_before
 
503
        """
 
504
        if self._remote_version_is_before is None:
 
505
            # So far, the remote side seems to support everything
 
506
            return False
 
507
        return version_tuple >= self._remote_version_is_before
 
508
 
 
509
    def _remember_remote_is_before(self, version_tuple):
 
510
        """Tell this medium that the remote side is older the given version.
 
511
 
 
512
        :seealso: _is_remote_before
 
513
        """
 
514
        if (self._remote_version_is_before is not None and
 
515
            version_tuple > self._remote_version_is_before):
 
516
            raise AssertionError(
 
517
                "_remember_remote_is_before(%r) called, but "
 
518
                "_remember_remote_is_before(%r) was called previously."
 
519
                % (version_tuple, self._remote_version_is_before))
 
520
        self._remote_version_is_before = version_tuple
 
521
 
 
522
    def protocol_version(self):
 
523
        """Find out if 'hello' smart request works."""
 
524
        if self._protocol_version_error is not None:
 
525
            raise self._protocol_version_error
 
526
        if not self._done_hello:
 
527
            try:
 
528
                medium_request = self.get_request()
 
529
                # Send a 'hello' request in protocol version one, for maximum
 
530
                # backwards compatibility.
 
531
                client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
 
532
                client_protocol.query_version()
 
533
                self._done_hello = True
 
534
            except errors.SmartProtocolError, e:
 
535
                # Cache the error, just like we would cache a successful
 
536
                # result.
 
537
                self._protocol_version_error = e
 
538
                raise
 
539
        return '2'
 
540
 
 
541
    def should_probe(self):
 
542
        """Should RemoteBzrDirFormat.probe_transport send a smart request on
 
543
        this medium?
 
544
 
 
545
        Some transports are unambiguously smart-only; there's no need to check
 
546
        if the transport is able to carry smart requests, because that's all
 
547
        it is for.  In those cases, this method should return False.
 
548
 
 
549
        But some HTTP transports can sometimes fail to carry smart requests,
 
550
        but still be usuable for accessing remote bzrdirs via plain file
 
551
        accesses.  So for those transports, their media should return True here
 
552
        so that RemoteBzrDirFormat can determine if it is appropriate for that
 
553
        transport.
 
554
        """
 
555
        return False
 
556
 
 
557
    def disconnect(self):
 
558
        """If this medium maintains a persistent connection, close it.
 
559
        
 
560
        The default implementation does nothing.
 
561
        """
 
562
        
 
563
    def remote_path_from_transport(self, transport):
 
564
        """Convert transport into a path suitable for using in a request.
 
565
        
 
566
        Note that the resulting remote path doesn't encode the host name or
 
567
        anything but path, so it is only safe to use it in requests sent over
 
568
        the medium from the matching transport.
 
569
        """
 
570
        medium_base = urlutils.join(self.base, '/')
 
571
        rel_url = urlutils.relative_url(medium_base, transport.base)
 
572
        return urllib.unquote(rel_url)
 
573
 
 
574
 
 
575
class SmartClientStreamMedium(SmartClientMedium):
 
576
    """Stream based medium common class.
 
577
 
 
578
    SmartClientStreamMediums operate on a stream. All subclasses use a common
 
579
    SmartClientStreamMediumRequest for their requests, and should implement
 
580
    _accept_bytes and _read_bytes to allow the request objects to send and
 
581
    receive bytes.
 
582
    """
 
583
 
 
584
    def __init__(self, base):
 
585
        SmartClientMedium.__init__(self, base)
 
586
        self._current_request = None
 
587
 
 
588
    def accept_bytes(self, bytes):
 
589
        self._accept_bytes(bytes)
 
590
 
 
591
    def __del__(self):
 
592
        """The SmartClientStreamMedium knows how to close the stream when it is
 
593
        finished with it.
 
594
        """
 
595
        self.disconnect()
 
596
 
 
597
    def _flush(self):
 
598
        """Flush the output stream.
 
599
        
 
600
        This method is used by the SmartClientStreamMediumRequest to ensure that
 
601
        all data for a request is sent, to avoid long timeouts or deadlocks.
 
602
        """
 
603
        raise NotImplementedError(self._flush)
 
604
 
 
605
    def get_request(self):
 
606
        """See SmartClientMedium.get_request().
 
607
 
 
608
        SmartClientStreamMedium always returns a SmartClientStreamMediumRequest
 
609
        for get_request.
 
610
        """
 
611
        return SmartClientStreamMediumRequest(self)
 
612
 
 
613
 
 
614
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
 
615
    """A client medium using simple pipes.
 
616
    
 
617
    This client does not manage the pipes: it assumes they will always be open.
 
618
    """
 
619
 
 
620
    def __init__(self, readable_pipe, writeable_pipe, base):
 
621
        SmartClientStreamMedium.__init__(self, base)
 
622
        self._readable_pipe = readable_pipe
 
623
        self._writeable_pipe = writeable_pipe
 
624
 
 
625
    def _accept_bytes(self, bytes):
 
626
        """See SmartClientStreamMedium.accept_bytes."""
 
627
        self._writeable_pipe.write(bytes)
 
628
 
 
629
    def _flush(self):
 
630
        """See SmartClientStreamMedium._flush()."""
 
631
        self._writeable_pipe.flush()
 
632
 
 
633
    def _read_bytes(self, count):
 
634
        """See SmartClientStreamMedium._read_bytes."""
 
635
        return self._readable_pipe.read(count)
 
636
 
 
637
 
 
638
class SmartSSHClientMedium(SmartClientStreamMedium):
 
639
    """A client medium using SSH."""
 
640
    
 
641
    def __init__(self, host, port=None, username=None, password=None,
 
642
            base=None, vendor=None, bzr_remote_path=None):
 
643
        """Creates a client that will connect on the first use.
 
644
        
 
645
        :param vendor: An optional override for the ssh vendor to use. See
 
646
            bzrlib.transport.ssh for details on ssh vendors.
 
647
        """
 
648
        SmartClientStreamMedium.__init__(self, base)
 
649
        self._connected = False
 
650
        self._host = host
 
651
        self._password = password
 
652
        self._port = port
 
653
        self._username = username
 
654
        self._read_from = None
 
655
        self._ssh_connection = None
 
656
        self._vendor = vendor
 
657
        self._write_to = None
 
658
        self._bzr_remote_path = bzr_remote_path
 
659
        if self._bzr_remote_path is None:
 
660
            symbol_versioning.warn(
 
661
                'bzr_remote_path is required as of bzr 0.92',
 
662
                DeprecationWarning, stacklevel=2)
 
663
            self._bzr_remote_path = os.environ.get('BZR_REMOTE_PATH', 'bzr')
 
664
 
 
665
    def _accept_bytes(self, bytes):
 
666
        """See SmartClientStreamMedium.accept_bytes."""
 
667
        self._ensure_connection()
 
668
        self._write_to.write(bytes)
 
669
 
 
670
    def disconnect(self):
 
671
        """See SmartClientMedium.disconnect()."""
 
672
        if not self._connected:
 
673
            return
 
674
        self._read_from.close()
 
675
        self._write_to.close()
 
676
        self._ssh_connection.close()
 
677
        self._connected = False
 
678
 
 
679
    def _ensure_connection(self):
 
680
        """Connect this medium if not already connected."""
 
681
        if self._connected:
 
682
            return
 
683
        if self._vendor is None:
 
684
            vendor = ssh._get_ssh_vendor()
 
685
        else:
 
686
            vendor = self._vendor
 
687
        self._ssh_connection = vendor.connect_ssh(self._username,
 
688
                self._password, self._host, self._port,
 
689
                command=[self._bzr_remote_path, 'serve', '--inet',
 
690
                         '--directory=/', '--allow-writes'])
 
691
        self._read_from, self._write_to = \
 
692
            self._ssh_connection.get_filelike_channels()
 
693
        self._connected = True
 
694
 
 
695
    def _flush(self):
 
696
        """See SmartClientStreamMedium._flush()."""
 
697
        self._write_to.flush()
 
698
 
 
699
    def _read_bytes(self, count):
 
700
        """See SmartClientStreamMedium.read_bytes."""
 
701
        if not self._connected:
 
702
            raise errors.MediumNotConnected(self)
 
703
        bytes_to_read = min(count, _MAX_READ_SIZE)
 
704
        return self._read_from.read(bytes_to_read)
 
705
 
 
706
 
 
707
# Port 4155 is the default port for bzr://, registered with IANA.
 
708
BZR_DEFAULT_INTERFACE = None
 
709
BZR_DEFAULT_PORT = 4155
 
710
 
 
711
 
 
712
class SmartTCPClientMedium(SmartClientStreamMedium):
 
713
    """A client medium using TCP."""
 
714
    
 
715
    def __init__(self, host, port, base):
 
716
        """Creates a client that will connect on the first use."""
 
717
        SmartClientStreamMedium.__init__(self, base)
 
718
        self._connected = False
 
719
        self._host = host
 
720
        self._port = port
 
721
        self._socket = None
 
722
 
 
723
    def _accept_bytes(self, bytes):
 
724
        """See SmartClientMedium.accept_bytes."""
 
725
        self._ensure_connection()
 
726
        osutils.send_all(self._socket, bytes)
 
727
 
 
728
    def disconnect(self):
 
729
        """See SmartClientMedium.disconnect()."""
 
730
        if not self._connected:
 
731
            return
 
732
        self._socket.close()
 
733
        self._socket = None
 
734
        self._connected = False
 
735
 
 
736
    def _ensure_connection(self):
 
737
        """Connect this medium if not already connected."""
 
738
        if self._connected:
 
739
            return
 
740
        if self._port is None:
 
741
            port = BZR_DEFAULT_PORT
 
742
        else:
 
743
            port = int(self._port)
 
744
        try:
 
745
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC, 
 
746
                socket.SOCK_STREAM, 0, 0)
 
747
        except socket.gaierror, (err_num, err_msg):
 
748
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
 
749
                    (self._host, port, err_msg))
 
750
        # Initialize err in case there are no addresses returned:
 
751
        err = socket.error("no address found for %s" % self._host)
 
752
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
 
753
            try:
 
754
                self._socket = socket.socket(family, socktype, proto)
 
755
                self._socket.setsockopt(socket.IPPROTO_TCP, 
 
756
                                        socket.TCP_NODELAY, 1)
 
757
                self._socket.connect(sockaddr)
 
758
            except socket.error, err:
 
759
                if self._socket is not None:
 
760
                    self._socket.close()
 
761
                self._socket = None
 
762
                continue
 
763
            break
 
764
        if self._socket is None:
 
765
            # socket errors either have a (string) or (errno, string) as their
 
766
            # args.
 
767
            if type(err.args) is str:
 
768
                err_msg = err.args
 
769
            else:
 
770
                err_msg = err.args[1]
 
771
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
 
772
                    (self._host, port, err_msg))
 
773
        self._connected = True
 
774
 
 
775
    def _flush(self):
 
776
        """See SmartClientStreamMedium._flush().
 
777
        
 
778
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and 
 
779
        add a means to do a flush, but that can be done in the future.
 
780
        """
 
781
 
 
782
    def _read_bytes(self, count):
 
783
        """See SmartClientMedium.read_bytes."""
 
784
        if not self._connected:
 
785
            raise errors.MediumNotConnected(self)
 
786
        # We ignore the desired_count because on sockets it's more efficient to
 
787
        # read large chunks (of _MAX_READ_SIZE bytes) at a time.
 
788
        try:
 
789
            return self._socket.recv(_MAX_READ_SIZE)
 
790
        except socket.error, e:
 
791
            if len(e.args) and e.args[0] == errno.ECONNRESET:
 
792
                # Callers expect an empty string in that case
 
793
                return ''
 
794
            else:
 
795
                raise
 
796
 
 
797
 
 
798
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
 
799
    """A SmartClientMediumRequest that works with an SmartClientStreamMedium."""
 
800
 
 
801
    def __init__(self, medium):
 
802
        SmartClientMediumRequest.__init__(self, medium)
 
803
        # check that we are safe concurrency wise. If some streams start
 
804
        # allowing concurrent requests - i.e. via multiplexing - then this
 
805
        # assert should be moved to SmartClientStreamMedium.get_request,
 
806
        # and the setting/unsetting of _current_request likewise moved into
 
807
        # that class : but its unneeded overhead for now. RBC 20060922
 
808
        if self._medium._current_request is not None:
 
809
            raise errors.TooManyConcurrentRequests(self._medium)
 
810
        self._medium._current_request = self
 
811
 
 
812
    def _accept_bytes(self, bytes):
 
813
        """See SmartClientMediumRequest._accept_bytes.
 
814
        
 
815
        This forwards to self._medium._accept_bytes because we are operating
 
816
        on the mediums stream.
 
817
        """
 
818
        self._medium._accept_bytes(bytes)
 
819
 
 
820
    def _finished_reading(self):
 
821
        """See SmartClientMediumRequest._finished_reading.
 
822
 
 
823
        This clears the _current_request on self._medium to allow a new 
 
824
        request to be created.
 
825
        """
 
826
        if self._medium._current_request is not self:
 
827
            raise AssertionError()
 
828
        self._medium._current_request = None
 
829
        
 
830
    def _finished_writing(self):
 
831
        """See SmartClientMediumRequest._finished_writing.
 
832
 
 
833
        This invokes self._medium._flush to ensure all bytes are transmitted.
 
834
        """
 
835
        self._medium._flush()
 
836