/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: Andrew Bennetts
  • Date: 2008-07-21 04:24:21 UTC
  • mto: This revision was merged to the branch mainline in revision 3568.
  • Revision ID: andrew.bennetts@canonical.com-20080721042421-63lh85e76o57jch4
Read no more then 64k at a time in the smart protocol code.

The logic for this has been moved entirely into bzrlib.smart.medium, and
duplication (both in that module, and in bzrlib.smart.protocol) has been mostly
refactored out.  In particular there's now a SmartMedium base class used for
both client- and server-side media, and only one place that reading a line is
implemented.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
"""The 'medium' layer for the smart servers and clients.
 
18
 
 
19
"Medium" here is the noun meaning "a means of transmission", not the adjective
 
20
for "the quality between big and small."
 
21
 
 
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
 
23
over SSH), and pass them to and from the protocol logic.  See the overview in
 
24
bzrlib/transport/smart/__init__.py.
 
25
"""
 
26
 
 
27
import os
 
28
import socket
 
29
import sys
 
30
import urllib
 
31
 
 
32
from bzrlib.lazy_import import lazy_import
 
33
lazy_import(globals(), """
 
34
from bzrlib import (
 
35
    errors,
 
36
    osutils,
 
37
    symbol_versioning,
 
38
    urlutils,
 
39
    )
 
40
from bzrlib.smart import protocol
 
41
from bzrlib.transport import ssh
 
42
""")
 
43
 
 
44
 
 
45
def _get_protocol_factory_for_bytes(bytes):
 
46
    """Determine the right protocol factory for 'bytes'.
 
47
 
 
48
    This will return an appropriate protocol factory depending on the version
 
49
    of the protocol being used, as determined by inspecting the given bytes.
 
50
    The bytes should have at least one newline byte (i.e. be a whole line),
 
51
    otherwise it's possible that a request will be incorrectly identified as
 
52
    version 1.
 
53
 
 
54
    Typical use would be::
 
55
 
 
56
         factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
 
57
         server_protocol = factory(transport, write_func, root_client_path)
 
58
         server_protocol.accept_bytes(unused_bytes)
 
59
 
 
60
    :param bytes: a str of bytes of the start of the request.
 
61
    :returns: 2-tuple of (protocol_factory, unused_bytes).  protocol_factory is
 
62
        a callable that takes three args: transport, write_func,
 
63
        root_client_path.  unused_bytes are any bytes that were not part of a
 
64
        protocol version marker.
 
65
    """
 
66
    if bytes.startswith(protocol.MESSAGE_VERSION_THREE):
 
67
        protocol_factory = protocol.build_server_protocol_three
 
68
        bytes = bytes[len(protocol.MESSAGE_VERSION_THREE):]
 
69
    elif bytes.startswith(protocol.REQUEST_VERSION_TWO):
 
70
        protocol_factory = protocol.SmartServerRequestProtocolTwo
 
71
        bytes = bytes[len(protocol.REQUEST_VERSION_TWO):]
 
72
    else:
 
73
        protocol_factory = protocol.SmartServerRequestProtocolOne
 
74
    return protocol_factory, bytes
 
75
 
 
76
 
 
77
class SmartMedium(object):
 
78
    """Base class for smart protocol media, both client- and server-side."""
 
79
 
 
80
    def __init__(self):
 
81
        self._push_back_buffer = None
 
82
        
 
83
    def _push_back(self, bytes):
 
84
        """Return unused bytes to the medium, because they belong to the next
 
85
        request(s).
 
86
 
 
87
        This sets the _push_back_buffer to the given bytes.
 
88
        """
 
89
        if self._push_back_buffer is not None:
 
90
            raise AssertionError(
 
91
                "_push_back called when self._push_back_buffer is %r"
 
92
                % (self._push_back_buffer,))
 
93
        if bytes == '':
 
94
            return
 
95
        self._push_back_buffer = bytes
 
96
 
 
97
    def _get_push_back_buffer(self):
 
98
        if self._push_back_buffer == '':
 
99
            raise AssertionError(
 
100
                '%s._push_back_buffer should never be the empty string, '
 
101
                'which can be confused with EOF' % (self,))
 
102
        bytes = self._push_back_buffer
 
103
        self._push_back_buffer = None
 
104
        return bytes
 
105
 
 
106
    def read_bytes(self, desired_count):
 
107
        max_read = 64 * 1024
 
108
        bytes_to_read = min(count, max_read)
 
109
        return self._read_bytes(bytes_to_read)
 
110
 
 
111
    def _read_bytes(self, count):
 
112
        raise NotImplementedError(self._read_bytes)
 
113
 
 
114
    def _get_line(self):
 
115
        """Read bytes from this request's response until a newline byte.
 
116
        
 
117
        This isn't particularly efficient, so should only be used when the
 
118
        expected size of the line is quite short.
 
119
 
 
120
        :returns: a string of bytes ending in a newline (byte 0x0A).
 
121
        """
 
122
        newline_pos = -1
 
123
        bytes = ''
 
124
        while newline_pos == -1:
 
125
            new_bytes = self._read_bytes(1)
 
126
            bytes += new_bytes
 
127
            if new_bytes == '':
 
128
                # Ran out of bytes before receiving a complete line.
 
129
                return bytes
 
130
            newline_pos = bytes.find('\n')
 
131
        line = bytes[:newline_pos+1]
 
132
        self._push_back(bytes[newline_pos+1:])
 
133
        return line
 
134
 
 
135
 
 
136
class SmartServerStreamMedium(SmartMedium):
 
137
    """Handles smart commands coming over a stream.
 
138
 
 
139
    The stream may be a pipe connected to sshd, or a tcp socket, or an
 
140
    in-process fifo for testing.
 
141
 
 
142
    One instance is created for each connected client; it can serve multiple
 
143
    requests in the lifetime of the connection.
 
144
 
 
145
    The server passes requests through to an underlying backing transport, 
 
146
    which will typically be a LocalTransport looking at the server's filesystem.
 
147
 
 
148
    :ivar _push_back_buffer: a str of bytes that have been read from the stream
 
149
        but not used yet, or None if there are no buffered bytes.  Subclasses
 
150
        should make sure to exhaust this buffer before reading more bytes from
 
151
        the stream.  See also the _push_back method.
 
152
    """
 
153
 
 
154
    def __init__(self, backing_transport, root_client_path='/'):
 
155
        """Construct new server.
 
156
 
 
157
        :param backing_transport: Transport for the directory served.
 
158
        """
 
159
        # backing_transport could be passed to serve instead of __init__
 
160
        self.backing_transport = backing_transport
 
161
        self.root_client_path = root_client_path
 
162
        self.finished = False
 
163
        SmartMedium.__init__(self)
 
164
 
 
165
    def serve(self):
 
166
        """Serve requests until the client disconnects."""
 
167
        # Keep a reference to stderr because the sys module's globals get set to
 
168
        # None during interpreter shutdown.
 
169
        from sys import stderr
 
170
        try:
 
171
            while not self.finished:
 
172
                server_protocol = self._build_protocol()
 
173
                self._serve_one_request(server_protocol)
 
174
        except Exception, e:
 
175
            stderr.write("%s terminating on exception %s\n" % (self, e))
 
176
            raise
 
177
 
 
178
    def _build_protocol(self):
 
179
        """Identifies the version of the incoming request, and returns an
 
180
        a protocol object that can interpret it.
 
181
 
 
182
        If more bytes than the version prefix of the request are read, they will
 
183
        be fed into the protocol before it is returned.
 
184
 
 
185
        :returns: a SmartServerRequestProtocol.
 
186
        """
 
187
        bytes = self._get_line()
 
188
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
 
189
        protocol = protocol_factory(
 
190
            self.backing_transport, self._write_out, self.root_client_path)
 
191
        protocol.accept_bytes(unused_bytes)
 
192
        return protocol
 
193
 
 
194
    def _serve_one_request(self, protocol):
 
195
        """Read one request from input, process, send back a response.
 
196
        
 
197
        :param protocol: a SmartServerRequestProtocol.
 
198
        """
 
199
        try:
 
200
            self._serve_one_request_unguarded(protocol)
 
201
        except KeyboardInterrupt:
 
202
            raise
 
203
        except Exception, e:
 
204
            self.terminate_due_to_error()
 
205
 
 
206
    def terminate_due_to_error(self):
 
207
        """Called when an unhandled exception from the protocol occurs."""
 
208
        raise NotImplementedError(self.terminate_due_to_error)
 
209
 
 
210
    def _get_bytes(self, desired_count):
 
211
        """Get some bytes from the medium.
 
212
 
 
213
        :param desired_count: number of bytes we want to read.
 
214
        """
 
215
        raise NotImplementedError(self._get_bytes)
 
216
 
 
217
 
 
218
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
 
219
 
 
220
    def __init__(self, sock, backing_transport, root_client_path='/'):
 
221
        """Constructor.
 
222
 
 
223
        :param sock: the socket the server will read from.  It will be put
 
224
            into blocking mode.
 
225
        """
 
226
        SmartServerStreamMedium.__init__(
 
227
            self, backing_transport, root_client_path=root_client_path)
 
228
        sock.setblocking(True)
 
229
        self.socket = sock
 
230
 
 
231
    def _serve_one_request_unguarded(self, protocol):
 
232
        while protocol.next_read_size():
 
233
            bytes = self._get_bytes(4096)
 
234
            if bytes == '':
 
235
                self.finished = True
 
236
                return
 
237
            protocol.accept_bytes(bytes)
 
238
        
 
239
        self._push_back(protocol.unused_data)
 
240
 
 
241
    def _get_bytes(self, desired_count):
 
242
        if self._push_back_buffer is not None:
 
243
            return self._get_push_back_buffer()
 
244
        # We ignore the desired_count because on sockets it's more efficient to
 
245
        # read 64k at a time.  Also, we must not read any more than 64k at a
 
246
        # time so that we don't risk error 10053 or 10055 on Windows (no buffer
 
247
        # space available).
 
248
        return self.socket.recv(64 * 1024)
 
249
 
 
250
    # XXX: duplication
 
251
    def _read_bytes(self, count):
 
252
        return self._get_bytes(count)
 
253
    
 
254
    def terminate_due_to_error(self):
 
255
        # TODO: This should log to a server log file, but no such thing
 
256
        # exists yet.  Andrew Bennetts 2006-09-29.
 
257
        self.socket.close()
 
258
        self.finished = True
 
259
 
 
260
    def _write_out(self, bytes):
 
261
        osutils.send_all(self.socket, bytes)
 
262
 
 
263
 
 
264
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
 
265
 
 
266
    def __init__(self, in_file, out_file, backing_transport):
 
267
        """Construct new server.
 
268
 
 
269
        :param in_file: Python file from which requests can be read.
 
270
        :param out_file: Python file to write responses.
 
271
        :param backing_transport: Transport for the directory served.
 
272
        """
 
273
        SmartServerStreamMedium.__init__(self, backing_transport)
 
274
        if sys.platform == 'win32':
 
275
            # force binary mode for files
 
276
            import msvcrt
 
277
            for f in (in_file, out_file):
 
278
                fileno = getattr(f, 'fileno', None)
 
279
                if fileno:
 
280
                    msvcrt.setmode(fileno(), os.O_BINARY)
 
281
        self._in = in_file
 
282
        self._out = out_file
 
283
 
 
284
    def _serve_one_request_unguarded(self, protocol):
 
285
        while True:
 
286
            bytes_to_read = protocol.next_read_size()
 
287
            if bytes_to_read == 0:
 
288
                # Finished serving this request.
 
289
                self._out.flush()
 
290
                return
 
291
            bytes = self._get_bytes(bytes_to_read)
 
292
            if bytes == '':
 
293
                # Connection has been closed.
 
294
                self.finished = True
 
295
                self._out.flush()
 
296
                return
 
297
            protocol.accept_bytes(bytes)
 
298
 
 
299
    def _get_bytes(self, desired_count):
 
300
        if self._push_back_buffer is not None:
 
301
            return self._get_push_back_buffer()
 
302
        return self._in.read(desired_count)
 
303
 
 
304
    # XXX: duplication
 
305
    def _read_bytes(self, count):
 
306
        return self._get_bytes(count)
 
307
    
 
308
    def terminate_due_to_error(self):
 
309
        # TODO: This should log to a server log file, but no such thing
 
310
        # exists yet.  Andrew Bennetts 2006-09-29.
 
311
        self._out.close()
 
312
        self.finished = True
 
313
 
 
314
    def _write_out(self, bytes):
 
315
        self._out.write(bytes)
 
316
 
 
317
 
 
318
class SmartClientMediumRequest(SmartMedium):
 
319
    """A request on a SmartClientMedium.
 
320
 
 
321
    Each request allows bytes to be provided to it via accept_bytes, and then
 
322
    the response bytes to be read via read_bytes.
 
323
 
 
324
    For instance:
 
325
    request.accept_bytes('123')
 
326
    request.finished_writing()
 
327
    result = request.read_bytes(3)
 
328
    request.finished_reading()
 
329
 
 
330
    It is up to the individual SmartClientMedium whether multiple concurrent
 
331
    requests can exist. See SmartClientMedium.get_request to obtain instances 
 
332
    of SmartClientMediumRequest, and the concrete Medium you are using for 
 
333
    details on concurrency and pipelining.
 
334
    """
 
335
 
 
336
    def __init__(self, medium):
 
337
        """Construct a SmartClientMediumRequest for the medium medium."""
 
338
        self._medium = medium
 
339
        # we track state by constants - we may want to use the same
 
340
        # pattern as BodyReader if it gets more complex.
 
341
        # valid states are: "writing", "reading", "done"
 
342
        self._state = "writing"
 
343
 
 
344
    def accept_bytes(self, bytes):
 
345
        """Accept bytes for inclusion in this request.
 
346
 
 
347
        This method may not be be called after finished_writing() has been
 
348
        called.  It depends upon the Medium whether or not the bytes will be
 
349
        immediately transmitted. Message based Mediums will tend to buffer the
 
350
        bytes until finished_writing() is called.
 
351
 
 
352
        :param bytes: A bytestring.
 
353
        """
 
354
        if self._state != "writing":
 
355
            raise errors.WritingCompleted(self)
 
356
        self._accept_bytes(bytes)
 
357
 
 
358
    def _accept_bytes(self, bytes):
 
359
        """Helper for accept_bytes.
 
360
 
 
361
        Accept_bytes checks the state of the request to determing if bytes
 
362
        should be accepted. After that it hands off to _accept_bytes to do the
 
363
        actual acceptance.
 
364
        """
 
365
        raise NotImplementedError(self._accept_bytes)
 
366
 
 
367
    def finished_reading(self):
 
368
        """Inform the request that all desired data has been read.
 
369
 
 
370
        This will remove the request from the pipeline for its medium (if the
 
371
        medium supports pipelining) and any further calls to methods on the
 
372
        request will raise ReadingCompleted.
 
373
        """
 
374
        if self._state == "writing":
 
375
            raise errors.WritingNotComplete(self)
 
376
        if self._state != "reading":
 
377
            raise errors.ReadingCompleted(self)
 
378
        self._state = "done"
 
379
        self._finished_reading()
 
380
 
 
381
    def _finished_reading(self):
 
382
        """Helper for finished_reading.
 
383
 
 
384
        finished_reading checks the state of the request to determine if 
 
385
        finished_reading is allowed, and if it is hands off to _finished_reading
 
386
        to perform the action.
 
387
        """
 
388
        raise NotImplementedError(self._finished_reading)
 
389
 
 
390
    def finished_writing(self):
 
391
        """Finish the writing phase of this request.
 
392
 
 
393
        This will flush all pending data for this request along the medium.
 
394
        After calling finished_writing, you may not call accept_bytes anymore.
 
395
        """
 
396
        if self._state != "writing":
 
397
            raise errors.WritingCompleted(self)
 
398
        self._state = "reading"
 
399
        self._finished_writing()
 
400
 
 
401
    def _finished_writing(self):
 
402
        """Helper for finished_writing.
 
403
 
 
404
        finished_writing checks the state of the request to determine if 
 
405
        finished_writing is allowed, and if it is hands off to _finished_writing
 
406
        to perform the action.
 
407
        """
 
408
        raise NotImplementedError(self._finished_writing)
 
409
 
 
410
    def read_bytes(self, count):
 
411
        """Read bytes from this requests response.
 
412
 
 
413
        This method will block and wait for count bytes to be read. It may not
 
414
        be invoked until finished_writing() has been called - this is to ensure
 
415
        a message-based approach to requests, for compatibility with message
 
416
        based mediums like HTTP.
 
417
        """
 
418
        if self._state == "writing":
 
419
            raise errors.WritingNotComplete(self)
 
420
        if self._state != "reading":
 
421
            raise errors.ReadingCompleted(self)
 
422
        return self._read_bytes(count)
 
423
 
 
424
    def _read_bytes(self, count):
 
425
        """See SmartClientMediumRequest._read_bytes.
 
426
        
 
427
        This forwards to self._medium._read_bytes because we are operating
 
428
        on the mediums stream.
 
429
        """
 
430
        return self._medium._read_bytes(count)
 
431
 
 
432
    def read_line(self):
 
433
        line = self._medium._get_line()
 
434
        if not line.endswith('\n'):
 
435
            # end of file encountered reading from server
 
436
            raise errors.ConnectionReset(
 
437
                "please check connectivity and permissions",
 
438
                "(and try -Dhpss if further diagnosis is required)")
 
439
        return line
 
440
 
 
441
 
 
442
class SmartClientMedium(SmartMedium):
 
443
    """Smart client is a medium for sending smart protocol requests over."""
 
444
 
 
445
    def __init__(self, base):
 
446
        super(SmartClientMedium, self).__init__()
 
447
        self.base = base
 
448
        self._protocol_version_error = None
 
449
        self._protocol_version = None
 
450
        self._done_hello = False
 
451
        # Be optimistic: we assume the remote end can accept new remote
 
452
        # requests until we get an error saying otherwise.
 
453
        # _remote_version_is_before tracks the bzr version the remote side
 
454
        # can be based on what we've seen so far.
 
455
        self._remote_version_is_before = None
 
456
 
 
457
    def _is_remote_before(self, version_tuple):
 
458
        """Is it possible the remote side supports RPCs for a given version?
 
459
 
 
460
        Typical use::
 
461
 
 
462
            needed_version = (1, 2)
 
463
            if medium._is_remote_before(needed_version):
 
464
                fallback_to_pre_1_2_rpc()
 
465
            else:
 
466
                try:
 
467
                    do_1_2_rpc()
 
468
                except UnknownSmartMethod:
 
469
                    medium._remember_remote_is_before(needed_version)
 
470
                    fallback_to_pre_1_2_rpc()
 
471
 
 
472
        :seealso: _remember_remote_is_before
 
473
        """
 
474
        if self._remote_version_is_before is None:
 
475
            # So far, the remote side seems to support everything
 
476
            return False
 
477
        return version_tuple >= self._remote_version_is_before
 
478
 
 
479
    def _remember_remote_is_before(self, version_tuple):
 
480
        """Tell this medium that the remote side is older the given version.
 
481
 
 
482
        :seealso: _is_remote_before
 
483
        """
 
484
        if (self._remote_version_is_before is not None and
 
485
            version_tuple > self._remote_version_is_before):
 
486
            raise AssertionError(
 
487
                "_remember_remote_is_before(%r) called, but "
 
488
                "_remember_remote_is_before(%r) was called previously."
 
489
                % (version_tuple, self._remote_version_is_before))
 
490
        self._remote_version_is_before = version_tuple
 
491
 
 
492
    def protocol_version(self):
 
493
        """Find out if 'hello' smart request works."""
 
494
        if self._protocol_version_error is not None:
 
495
            raise self._protocol_version_error
 
496
        if not self._done_hello:
 
497
            try:
 
498
                medium_request = self.get_request()
 
499
                # Send a 'hello' request in protocol version one, for maximum
 
500
                # backwards compatibility.
 
501
                client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
 
502
                client_protocol.query_version()
 
503
                self._done_hello = True
 
504
            except errors.SmartProtocolError, e:
 
505
                # Cache the error, just like we would cache a successful
 
506
                # result.
 
507
                self._protocol_version_error = e
 
508
                raise
 
509
        return '2'
 
510
 
 
511
    def should_probe(self):
 
512
        """Should RemoteBzrDirFormat.probe_transport send a smart request on
 
513
        this medium?
 
514
 
 
515
        Some transports are unambiguously smart-only; there's no need to check
 
516
        if the transport is able to carry smart requests, because that's all
 
517
        it is for.  In those cases, this method should return False.
 
518
 
 
519
        But some HTTP transports can sometimes fail to carry smart requests,
 
520
        but still be usuable for accessing remote bzrdirs via plain file
 
521
        accesses.  So for those transports, their media should return True here
 
522
        so that RemoteBzrDirFormat can determine if it is appropriate for that
 
523
        transport.
 
524
        """
 
525
        return False
 
526
 
 
527
    def disconnect(self):
 
528
        """If this medium maintains a persistent connection, close it.
 
529
        
 
530
        The default implementation does nothing.
 
531
        """
 
532
        
 
533
    def remote_path_from_transport(self, transport):
 
534
        """Convert transport into a path suitable for using in a request.
 
535
        
 
536
        Note that the resulting remote path doesn't encode the host name or
 
537
        anything but path, so it is only safe to use it in requests sent over
 
538
        the medium from the matching transport.
 
539
        """
 
540
        medium_base = urlutils.join(self.base, '/')
 
541
        rel_url = urlutils.relative_url(medium_base, transport.base)
 
542
        return urllib.unquote(rel_url)
 
543
 
 
544
 
 
545
class SmartClientStreamMedium(SmartClientMedium):
 
546
    """Stream based medium common class.
 
547
 
 
548
    SmartClientStreamMediums operate on a stream. All subclasses use a common
 
549
    SmartClientStreamMediumRequest for their requests, and should implement
 
550
    _accept_bytes and _read_bytes to allow the request objects to send and
 
551
    receive bytes.
 
552
    """
 
553
 
 
554
    def __init__(self, base):
 
555
        SmartClientMedium.__init__(self, base)
 
556
        self._current_request = None
 
557
 
 
558
    def accept_bytes(self, bytes):
 
559
        self._accept_bytes(bytes)
 
560
 
 
561
    def __del__(self):
 
562
        """The SmartClientStreamMedium knows how to close the stream when it is
 
563
        finished with it.
 
564
        """
 
565
        self.disconnect()
 
566
 
 
567
    def _flush(self):
 
568
        """Flush the output stream.
 
569
        
 
570
        This method is used by the SmartClientStreamMediumRequest to ensure that
 
571
        all data for a request is sent, to avoid long timeouts or deadlocks.
 
572
        """
 
573
        raise NotImplementedError(self._flush)
 
574
 
 
575
    def get_request(self):
 
576
        """See SmartClientMedium.get_request().
 
577
 
 
578
        SmartClientStreamMedium always returns a SmartClientStreamMediumRequest
 
579
        for get_request.
 
580
        """
 
581
        return SmartClientStreamMediumRequest(self)
 
582
 
 
583
    def read_bytes(self, count):
 
584
        return self._read_bytes(count)
 
585
 
 
586
 
 
587
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
 
588
    """A client medium using simple pipes.
 
589
    
 
590
    This client does not manage the pipes: it assumes they will always be open.
 
591
    """
 
592
 
 
593
    def __init__(self, readable_pipe, writeable_pipe, base):
 
594
        SmartClientStreamMedium.__init__(self, base)
 
595
        self._readable_pipe = readable_pipe
 
596
        self._writeable_pipe = writeable_pipe
 
597
 
 
598
    def _accept_bytes(self, bytes):
 
599
        """See SmartClientStreamMedium.accept_bytes."""
 
600
        self._writeable_pipe.write(bytes)
 
601
 
 
602
    def _flush(self):
 
603
        """See SmartClientStreamMedium._flush()."""
 
604
        self._writeable_pipe.flush()
 
605
 
 
606
    def _read_bytes(self, count):
 
607
        """See SmartClientStreamMedium._read_bytes."""
 
608
        if self._push_back_buffer is not None:
 
609
            return self._get_push_back_buffer()
 
610
        return self._readable_pipe.read(count)
 
611
 
 
612
 
 
613
class SmartSSHClientMedium(SmartClientStreamMedium):
 
614
    """A client medium using SSH."""
 
615
    
 
616
    def __init__(self, host, port=None, username=None, password=None,
 
617
            base=None, vendor=None, bzr_remote_path=None):
 
618
        """Creates a client that will connect on the first use.
 
619
        
 
620
        :param vendor: An optional override for the ssh vendor to use. See
 
621
            bzrlib.transport.ssh for details on ssh vendors.
 
622
        """
 
623
        SmartClientStreamMedium.__init__(self, base)
 
624
        self._connected = False
 
625
        self._host = host
 
626
        self._password = password
 
627
        self._port = port
 
628
        self._username = username
 
629
        self._read_from = None
 
630
        self._ssh_connection = None
 
631
        self._vendor = vendor
 
632
        self._write_to = None
 
633
        self._bzr_remote_path = bzr_remote_path
 
634
        if self._bzr_remote_path is None:
 
635
            symbol_versioning.warn(
 
636
                'bzr_remote_path is required as of bzr 0.92',
 
637
                DeprecationWarning, stacklevel=2)
 
638
            self._bzr_remote_path = os.environ.get('BZR_REMOTE_PATH', 'bzr')
 
639
 
 
640
    def _accept_bytes(self, bytes):
 
641
        """See SmartClientStreamMedium.accept_bytes."""
 
642
        self._ensure_connection()
 
643
        self._write_to.write(bytes)
 
644
 
 
645
    def disconnect(self):
 
646
        """See SmartClientMedium.disconnect()."""
 
647
        if not self._connected:
 
648
            return
 
649
        self._read_from.close()
 
650
        self._write_to.close()
 
651
        self._ssh_connection.close()
 
652
        self._connected = False
 
653
 
 
654
    def _ensure_connection(self):
 
655
        """Connect this medium if not already connected."""
 
656
        if self._connected:
 
657
            return
 
658
        if self._vendor is None:
 
659
            vendor = ssh._get_ssh_vendor()
 
660
        else:
 
661
            vendor = self._vendor
 
662
        self._ssh_connection = vendor.connect_ssh(self._username,
 
663
                self._password, self._host, self._port,
 
664
                command=[self._bzr_remote_path, 'serve', '--inet',
 
665
                         '--directory=/', '--allow-writes'])
 
666
        self._read_from, self._write_to = \
 
667
            self._ssh_connection.get_filelike_channels()
 
668
        self._connected = True
 
669
 
 
670
    def _flush(self):
 
671
        """See SmartClientStreamMedium._flush()."""
 
672
        self._write_to.flush()
 
673
 
 
674
    def _read_bytes(self, count):
 
675
        """See SmartClientStreamMedium.read_bytes."""
 
676
        if not self._connected:
 
677
            raise errors.MediumNotConnected(self)
 
678
        # Read no more than 64k at a time so that we don't risk error 10053 or
 
679
        # 10055 on Windows (no buffer space available).
 
680
        max_read = 64 * 1024
 
681
        bytes_to_read = min(count, max_read)
 
682
        return self._read_from.read(bytes_to_read)
 
683
 
 
684
 
 
685
# Port 4155 is the default port for bzr://, registered with IANA.
 
686
BZR_DEFAULT_INTERFACE = '0.0.0.0'
 
687
BZR_DEFAULT_PORT = 4155
 
688
 
 
689
 
 
690
class SmartTCPClientMedium(SmartClientStreamMedium):
 
691
    """A client medium using TCP."""
 
692
    
 
693
    def __init__(self, host, port, base):
 
694
        """Creates a client that will connect on the first use."""
 
695
        SmartClientStreamMedium.__init__(self, base)
 
696
        self._connected = False
 
697
        self._host = host
 
698
        self._port = port
 
699
        self._socket = None
 
700
 
 
701
    def _accept_bytes(self, bytes):
 
702
        """See SmartClientMedium.accept_bytes."""
 
703
        self._ensure_connection()
 
704
        osutils.send_all(self._socket, bytes)
 
705
 
 
706
    def disconnect(self):
 
707
        """See SmartClientMedium.disconnect()."""
 
708
        if not self._connected:
 
709
            return
 
710
        self._socket.close()
 
711
        self._socket = None
 
712
        self._connected = False
 
713
 
 
714
    def _ensure_connection(self):
 
715
        """Connect this medium if not already connected."""
 
716
        if self._connected:
 
717
            return
 
718
        self._socket = socket.socket()
 
719
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
720
        if self._port is None:
 
721
            port = BZR_DEFAULT_PORT
 
722
        else:
 
723
            port = int(self._port)
 
724
        try:
 
725
            self._socket.connect((self._host, port))
 
726
        except socket.error, err:
 
727
            # socket errors either have a (string) or (errno, string) as their
 
728
            # args.
 
729
            if type(err.args) is str:
 
730
                err_msg = err.args
 
731
            else:
 
732
                err_msg = err.args[1]
 
733
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
 
734
                    (self._host, port, err_msg))
 
735
        self._connected = True
 
736
 
 
737
    def _flush(self):
 
738
        """See SmartClientStreamMedium._flush().
 
739
        
 
740
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and 
 
741
        add a means to do a flush, but that can be done in the future.
 
742
        """
 
743
 
 
744
    def _read_bytes(self, count):
 
745
        """See SmartClientMedium.read_bytes."""
 
746
        if not self._connected:
 
747
            raise errors.MediumNotConnected(self)
 
748
        if self._push_back_buffer is not None:
 
749
            return self._get_push_back_buffer()
 
750
        # We ignore the desired count because on sockets it's more efficient to
 
751
        # read 64k at a time.  Also, we must not read any more than 64k at a
 
752
        # time so that we don't risk error 10053 or 10055 on Windows (no buffer
 
753
        # space available).
 
754
        return self._socket.recv(64 * 1024)
 
755
 
 
756
 
 
757
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
 
758
    """A SmartClientMediumRequest that works with an SmartClientStreamMedium."""
 
759
 
 
760
    def __init__(self, medium):
 
761
        SmartClientMediumRequest.__init__(self, medium)
 
762
        # check that we are safe concurrency wise. If some streams start
 
763
        # allowing concurrent requests - i.e. via multiplexing - then this
 
764
        # assert should be moved to SmartClientStreamMedium.get_request,
 
765
        # and the setting/unsetting of _current_request likewise moved into
 
766
        # that class : but its unneeded overhead for now. RBC 20060922
 
767
        if self._medium._current_request is not None:
 
768
            raise errors.TooManyConcurrentRequests(self._medium)
 
769
        self._medium._current_request = self
 
770
 
 
771
    def _accept_bytes(self, bytes):
 
772
        """See SmartClientMediumRequest._accept_bytes.
 
773
        
 
774
        This forwards to self._medium._accept_bytes because we are operating
 
775
        on the mediums stream.
 
776
        """
 
777
        self._medium._accept_bytes(bytes)
 
778
 
 
779
    def _finished_reading(self):
 
780
        """See SmartClientMediumRequest._finished_reading.
 
781
 
 
782
        This clears the _current_request on self._medium to allow a new 
 
783
        request to be created.
 
784
        """
 
785
        if self._medium._current_request is not self:
 
786
            raise AssertionError()
 
787
        self._medium._current_request = None
 
788
        
 
789
    def _finished_writing(self):
 
790
        """See SmartClientMediumRequest._finished_writing.
 
791
 
 
792
        This invokes self._medium._flush to ensure all bytes are transmitted.
 
793
        """
 
794
        self._medium._flush()
 
795
 
 
796
    def _read_bytes(self, count):
 
797
        """See SmartClientMediumRequest._read_bytes.
 
798
        
 
799
        This forwards to self._medium._read_bytes because we are operating
 
800
        on the mediums stream.
 
801
        """
 
802
        return self._medium._read_bytes(count)
 
803