/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: John Arbash Meinel
  • Author(s): Mark Hammond
  • Date: 2008-09-09 17:02:21 UTC
  • mto: This revision was merged to the branch mainline in revision 3697.
  • Revision ID: john@arbash-meinel.com-20080909170221-svim3jw2mrz0amp3
An updated transparent icon for bzr.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2011 Canonical Ltd
 
1
# Copyright (C) 2006 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
12
12
#
13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
16
16
 
17
17
"""The 'medium' layer for the smart servers and clients.
18
18
 
21
21
 
22
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
23
23
over SSH), and pass them to and from the protocol logic.  See the overview in
24
 
breezy/transport/smart/__init__.py.
 
24
bzrlib/transport/smart/__init__.py.
25
25
"""
26
26
 
27
 
from __future__ import absolute_import
28
 
 
29
 
import errno
30
 
import io
31
27
import os
 
28
import socket
32
29
import sys
33
 
import time
 
30
import urllib
34
31
 
35
 
import breezy
36
 
from ...lazy_import import lazy_import
 
32
from bzrlib.lazy_import import lazy_import
37
33
lazy_import(globals(), """
38
 
import select
39
 
import socket
40
 
import thread
41
 
import weakref
42
 
 
43
 
from breezy import (
44
 
    debug,
45
 
    trace,
46
 
    transport,
47
 
    ui,
 
34
from bzrlib import (
 
35
    errors,
 
36
    osutils,
 
37
    symbol_versioning,
48
38
    urlutils,
49
39
    )
50
 
from breezy.i18n import gettext
51
 
from breezy.bzr.smart import client, protocol, request, signals, vfs
52
 
from breezy.transport import ssh
 
40
from bzrlib.smart import protocol
 
41
from bzrlib.transport import ssh
53
42
""")
54
 
from ... import (
55
 
    errors,
56
 
    osutils,
57
 
    )
58
 
 
59
 
# Throughout this module buffer size parameters are either limited to be at
60
 
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
61
 
# For this module's purposes, MAX_SOCKET_CHUNK is a reasonable size for reads
62
 
# from non-sockets as well.
63
 
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
64
 
 
65
 
 
66
 
class HpssVfsRequestNotAllowed(errors.BzrError):
67
 
 
68
 
    _fmt = ("VFS requests over the smart server are not allowed. Encountered: "
69
 
            "%(method)s, %(arguments)s.")
70
 
 
71
 
    def __init__(self, method, arguments):
72
 
        self.method = method
73
 
        self.arguments = arguments
 
43
 
 
44
 
 
45
# We must not read any more than 64k at a time so we don't risk "no buffer
 
46
# space available" errors on some platforms.  Windows in particular is likely
 
47
# to give error 10053 or 10055 if we read more than 64k from a socket.
 
48
_MAX_READ_SIZE = 64 * 1024
74
49
 
75
50
 
76
51
def _get_protocol_factory_for_bytes(bytes):
107
82
 
108
83
def _get_line(read_bytes_func):
109
84
    """Read bytes using read_bytes_func until a newline byte.
110
 
 
 
85
    
111
86
    This isn't particularly efficient, so should only be used when the
112
87
    expected size of the line is quite short.
113
 
 
 
88
    
114
89
    :returns: a tuple of two strs: (line, excess)
115
90
    """
116
91
    newline_pos = -1
117
 
    bytes = b''
 
92
    bytes = ''
118
93
    while newline_pos == -1:
119
94
        new_bytes = read_bytes_func(1)
120
95
        bytes += new_bytes
121
 
        if new_bytes == b'':
 
96
        if new_bytes == '':
122
97
            # Ran out of bytes before receiving a complete line.
123
 
            return bytes, b''
124
 
        newline_pos = bytes.find(b'\n')
 
98
            return bytes, ''
 
99
        newline_pos = bytes.find('\n')
125
100
    line = bytes[:newline_pos+1]
126
101
    excess = bytes[newline_pos+1:]
127
102
    return line, excess
132
107
 
133
108
    def __init__(self):
134
109
        self._push_back_buffer = None
135
 
 
136
 
    def _push_back(self, data):
 
110
        
 
111
    def _push_back(self, bytes):
137
112
        """Return unused bytes to the medium, because they belong to the next
138
113
        request(s).
139
114
 
140
115
        This sets the _push_back_buffer to the given bytes.
141
116
        """
142
 
        if not isinstance(data, bytes):
143
 
            raise TypeError(data)
144
117
        if self._push_back_buffer is not None:
145
118
            raise AssertionError(
146
119
                "_push_back called when self._push_back_buffer is %r"
147
120
                % (self._push_back_buffer,))
148
 
        if data == b'':
 
121
        if bytes == '':
149
122
            return
150
 
        self._push_back_buffer = data
 
123
        self._push_back_buffer = bytes
151
124
 
152
125
    def _get_push_back_buffer(self):
153
 
        if self._push_back_buffer == b'':
 
126
        if self._push_back_buffer == '':
154
127
            raise AssertionError(
155
128
                '%s._push_back_buffer should never be the empty string, '
156
129
                'which can be confused with EOF' % (self,))
174
147
 
175
148
    def _get_line(self):
176
149
        """Read bytes from this request's response until a newline byte.
177
 
 
 
150
        
178
151
        This isn't particularly efficient, so should only be used when the
179
152
        expected size of the line is quite short.
180
153
 
183
156
        line, excess = _get_line(self.read_bytes)
184
157
        self._push_back(excess)
185
158
        return line
186
 
 
187
 
    def _report_activity(self, bytes, direction):
188
 
        """Notify that this medium has activity.
189
 
 
190
 
        Implementations should call this from all methods that actually do IO.
191
 
        Be careful that it's not called twice, if one method is implemented on
192
 
        top of another.
193
 
 
194
 
        :param bytes: Number of bytes read or written.
195
 
        :param direction: 'read' or 'write' or None.
196
 
        """
197
 
        ui.ui_factory.report_transport_activity(self, bytes, direction)
198
 
 
199
 
 
200
 
_bad_file_descriptor = (errno.EBADF,)
201
 
if sys.platform == 'win32':
202
 
    # Given on Windows if you pass a closed socket to select.select. Probably
203
 
    # also given if you pass a file handle to select.
204
 
    WSAENOTSOCK = 10038
205
 
    _bad_file_descriptor += (WSAENOTSOCK,)
206
 
 
 
159
 
207
160
 
208
161
class SmartServerStreamMedium(SmartMedium):
209
162
    """Handles smart commands coming over a stream.
214
167
    One instance is created for each connected client; it can serve multiple
215
168
    requests in the lifetime of the connection.
216
169
 
217
 
    The server passes requests through to an underlying backing transport,
 
170
    The server passes requests through to an underlying backing transport, 
218
171
    which will typically be a LocalTransport looking at the server's filesystem.
219
172
 
220
173
    :ivar _push_back_buffer: a str of bytes that have been read from the stream
223
176
        the stream.  See also the _push_back method.
224
177
    """
225
178
 
226
 
    _timer = time.time
227
 
 
228
 
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
 
179
    def __init__(self, backing_transport, root_client_path='/'):
229
180
        """Construct new server.
230
181
 
231
182
        :param backing_transport: Transport for the directory served.
234
185
        self.backing_transport = backing_transport
235
186
        self.root_client_path = root_client_path
236
187
        self.finished = False
237
 
        if timeout is None:
238
 
            raise AssertionError('You must supply a timeout.')
239
 
        self._client_timeout = timeout
240
 
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
241
188
        SmartMedium.__init__(self)
242
189
 
243
190
    def serve(self):
249
196
            while not self.finished:
250
197
                server_protocol = self._build_protocol()
251
198
                self._serve_one_request(server_protocol)
252
 
        except errors.ConnectionTimeout as e:
253
 
            trace.note('%s' % (e,))
254
 
            trace.log_exception_quietly()
255
 
            self._disconnect_client()
256
 
            # We reported it, no reason to make a big fuss.
257
 
            return
258
 
        except Exception as e:
 
199
        except Exception, e:
259
200
            stderr.write("%s terminating on exception %s\n" % (self, e))
260
201
            raise
261
 
        self._disconnect_client()
262
 
 
263
 
    def _stop_gracefully(self):
264
 
        """When we finish this message, stop looking for more."""
265
 
        trace.mutter('Stopping %s' % (self,))
266
 
        self.finished = True
267
 
 
268
 
    def _disconnect_client(self):
269
 
        """Close the current connection. We stopped due to a timeout/etc."""
270
 
        # The default implementation is a no-op, because that is all we used to
271
 
        # do when disconnecting from a client. I suppose we never had the
272
 
        # *server* initiate a disconnect, before
273
 
 
274
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
275
 
        """Wait for more bytes to be read, but timeout if none available.
276
 
 
277
 
        This allows us to detect idle connections, and stop trying to read from
278
 
        them, without setting the socket itself to non-blocking. This also
279
 
        allows us to specify when we watch for idle timeouts.
280
 
 
281
 
        :return: Did we timeout? (True if we timed out, False if there is data
282
 
            to be read)
283
 
        """
284
 
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
285
202
 
286
203
    def _build_protocol(self):
287
204
        """Identifies the version of the incoming request, and returns an
292
209
 
293
210
        :returns: a SmartServerRequestProtocol.
294
211
        """
295
 
        self._wait_for_bytes_with_timeout(self._client_timeout)
296
 
        if self.finished:
297
 
            # We're stopping, so don't try to do any more work
298
 
            return None
299
212
        bytes = self._get_line()
300
213
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
301
214
        protocol = protocol_factory(
303
216
        protocol.accept_bytes(unused_bytes)
304
217
        return protocol
305
218
 
306
 
    def _wait_on_descriptor(self, fd, timeout_seconds):
307
 
        """select() on a file descriptor, waiting for nonblocking read()
308
 
 
309
 
        This will raise a ConnectionTimeout exception if we do not get a
310
 
        readable handle before timeout_seconds.
311
 
        :return: None
312
 
        """
313
 
        t_end = self._timer() + timeout_seconds
314
 
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
315
 
        rs = xs = None
316
 
        while not rs and not xs and self._timer() < t_end:
317
 
            if self.finished:
318
 
                return
319
 
            try:
320
 
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
321
 
            except (select.error, socket.error) as e:
322
 
                err = getattr(e, 'errno', None)
323
 
                if err is None and getattr(e, 'args', None) is not None:
324
 
                    # select.error doesn't have 'errno', it just has args[0]
325
 
                    err = e.args[0]
326
 
                if err in _bad_file_descriptor:
327
 
                    return # Not a socket indicates read() will fail
328
 
                elif err == errno.EINTR:
329
 
                    # Interrupted, keep looping.
330
 
                    continue
331
 
                raise
332
 
            except ValueError:
333
 
                return  # Socket may already be closed
334
 
        if rs or xs:
335
 
            return
336
 
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
337
 
                                       % (timeout_seconds,))
338
 
 
339
219
    def _serve_one_request(self, protocol):
340
220
        """Read one request from input, process, send back a response.
341
 
 
 
221
        
342
222
        :param protocol: a SmartServerRequestProtocol.
343
223
        """
344
 
        if protocol is None:
345
 
            return
346
224
        try:
347
225
            self._serve_one_request_unguarded(protocol)
348
226
        except KeyboardInterrupt:
349
227
            raise
350
 
        except Exception as e:
 
228
        except Exception, e:
351
229
            self.terminate_due_to_error()
352
230
 
353
231
    def terminate_due_to_error(self):
364
242
 
365
243
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
366
244
 
367
 
    def __init__(self, sock, backing_transport, root_client_path='/',
368
 
                 timeout=None):
 
245
    def __init__(self, sock, backing_transport, root_client_path='/'):
369
246
        """Constructor.
370
247
 
371
248
        :param sock: the socket the server will read from.  It will be put
372
249
            into blocking mode.
373
250
        """
374
251
        SmartServerStreamMedium.__init__(
375
 
            self, backing_transport, root_client_path=root_client_path,
376
 
            timeout=timeout)
 
252
            self, backing_transport, root_client_path=root_client_path)
377
253
        sock.setblocking(True)
378
254
        self.socket = sock
379
 
        # Get the getpeername now, as we might be closed later when we care.
380
 
        try:
381
 
            self._client_info = sock.getpeername()
382
 
        except socket.error:
383
 
            self._client_info = '<unknown>'
384
 
 
385
 
    def __str__(self):
386
 
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
387
 
 
388
 
    def __repr__(self):
389
 
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
390
 
            self._client_info)
391
255
 
392
256
    def _serve_one_request_unguarded(self, protocol):
393
257
        while protocol.next_read_size():
394
258
            # We can safely try to read large chunks.  If there is less data
395
 
            # than MAX_SOCKET_CHUNK ready, the socket will just return a
396
 
            # short read immediately rather than block.
397
 
            bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
398
 
            if bytes == b'':
 
259
            # than _MAX_READ_SIZE ready, the socket wil just return a short
 
260
            # read immediately rather than block.
 
261
            bytes = self.read_bytes(_MAX_READ_SIZE)
 
262
            if bytes == '':
399
263
                self.finished = True
400
264
                return
401
265
            protocol.accept_bytes(bytes)
402
 
 
 
266
        
403
267
        self._push_back(protocol.unused_data)
404
268
 
405
 
    def _disconnect_client(self):
406
 
        """Close the current connection. We stopped due to a timeout/etc."""
407
 
        self.socket.close()
408
 
 
409
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
410
 
        """Wait for more bytes to be read, but timeout if none available.
411
 
 
412
 
        This allows us to detect idle connections, and stop trying to read from
413
 
        them, without setting the socket itself to non-blocking. This also
414
 
        allows us to specify when we watch for idle timeouts.
415
 
 
416
 
        :return: None, this will raise ConnectionTimeout if we time out before
417
 
            data is available.
418
 
        """
419
 
        return self._wait_on_descriptor(self.socket, timeout_seconds)
420
 
 
421
269
    def _read_bytes(self, desired_count):
422
 
        return osutils.read_bytes_from_socket(
423
 
            self.socket, self._report_activity)
 
270
        # We ignore the desired_count because on sockets it's more efficient to
 
271
        # read large chunks (of _MAX_READ_SIZE bytes) at a time.
 
272
        return self.socket.recv(_MAX_READ_SIZE)
424
273
 
425
274
    def terminate_due_to_error(self):
426
275
        # TODO: This should log to a server log file, but no such thing
429
278
        self.finished = True
430
279
 
431
280
    def _write_out(self, bytes):
432
 
        tstart = osutils.timer_func()
433
 
        osutils.send_all(self.socket, bytes, self._report_activity)
434
 
        if 'hpss' in debug.debug_flags:
435
 
            thread_id = thread.get_ident()
436
 
            trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
437
 
                         % ('wrote', thread_id, len(bytes),
438
 
                            osutils.timer_func() - tstart))
 
281
        osutils.send_all(self.socket, bytes)
439
282
 
440
283
 
441
284
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
442
285
 
443
 
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
 
286
    def __init__(self, in_file, out_file, backing_transport):
444
287
        """Construct new server.
445
288
 
446
289
        :param in_file: Python file from which requests can be read.
447
290
        :param out_file: Python file to write responses.
448
291
        :param backing_transport: Transport for the directory served.
449
292
        """
450
 
        SmartServerStreamMedium.__init__(self, backing_transport,
451
 
            timeout=timeout)
 
293
        SmartServerStreamMedium.__init__(self, backing_transport)
452
294
        if sys.platform == 'win32':
453
295
            # force binary mode for files
454
296
            import msvcrt
459
301
        self._in = in_file
460
302
        self._out = out_file
461
303
 
462
 
    def serve(self):
463
 
        """See SmartServerStreamMedium.serve"""
464
 
        # This is the regular serve, except it adds signal trapping for soft
465
 
        # shutdown.
466
 
        stop_gracefully = self._stop_gracefully
467
 
        signals.register_on_hangup(id(self), stop_gracefully)
468
 
        try:
469
 
            return super(SmartServerPipeStreamMedium, self).serve()
470
 
        finally:
471
 
            signals.unregister_on_hangup(id(self))
472
 
 
473
304
    def _serve_one_request_unguarded(self, protocol):
474
305
        while True:
475
306
            # We need to be careful not to read past the end of the current
481
312
                self._out.flush()
482
313
                return
483
314
            bytes = self.read_bytes(bytes_to_read)
484
 
            if bytes == b'':
 
315
            if bytes == '':
485
316
                # Connection has been closed.
486
317
                self.finished = True
487
318
                self._out.flush()
488
319
                return
489
320
            protocol.accept_bytes(bytes)
490
321
 
491
 
    def _disconnect_client(self):
492
 
        self._in.close()
493
 
        self._out.flush()
494
 
        self._out.close()
495
 
 
496
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
497
 
        """Wait for more bytes to be read, but timeout if none available.
498
 
 
499
 
        This allows us to detect idle connections, and stop trying to read from
500
 
        them, without setting the socket itself to non-blocking. This also
501
 
        allows us to specify when we watch for idle timeouts.
502
 
 
503
 
        :return: None, this will raise ConnectionTimeout if we time out before
504
 
            data is available.
505
 
        """
506
 
        if (getattr(self._in, 'fileno', None) is None
507
 
            or sys.platform == 'win32'):
508
 
            # You can't select() file descriptors on Windows.
509
 
            return
510
 
        try:
511
 
            return self._wait_on_descriptor(self._in, timeout_seconds)
512
 
        except io.UnsupportedOperation:
513
 
            return
514
 
 
515
322
    def _read_bytes(self, desired_count):
516
323
        return self._in.read(desired_count)
517
324
 
538
345
    request.finished_reading()
539
346
 
540
347
    It is up to the individual SmartClientMedium whether multiple concurrent
541
 
    requests can exist. See SmartClientMedium.get_request to obtain instances
542
 
    of SmartClientMediumRequest, and the concrete Medium you are using for
 
348
    requests can exist. See SmartClientMedium.get_request to obtain instances 
 
349
    of SmartClientMediumRequest, and the concrete Medium you are using for 
543
350
    details on concurrency and pipelining.
544
351
    """
545
352
 
554
361
    def accept_bytes(self, bytes):
555
362
        """Accept bytes for inclusion in this request.
556
363
 
557
 
        This method may not be called after finished_writing() has been
 
364
        This method may not be be called after finished_writing() has been
558
365
        called.  It depends upon the Medium whether or not the bytes will be
559
366
        immediately transmitted. Message based Mediums will tend to buffer the
560
367
        bytes until finished_writing() is called.
591
398
    def _finished_reading(self):
592
399
        """Helper for finished_reading.
593
400
 
594
 
        finished_reading checks the state of the request to determine if
 
401
        finished_reading checks the state of the request to determine if 
595
402
        finished_reading is allowed, and if it is hands off to _finished_reading
596
403
        to perform the action.
597
404
        """
611
418
    def _finished_writing(self):
612
419
        """Helper for finished_writing.
613
420
 
614
 
        finished_writing checks the state of the request to determine if
 
421
        finished_writing checks the state of the request to determine if 
615
422
        finished_writing is allowed, and if it is hands off to _finished_writing
616
423
        to perform the action.
617
424
        """
637
444
        read_bytes checks the state of the request to determing if bytes
638
445
        should be read. After that it hands off to _read_bytes to do the
639
446
        actual read.
640
 
 
 
447
        
641
448
        By default this forwards to self._medium.read_bytes because we are
642
449
        operating on the medium's stream.
643
450
        """
645
452
 
646
453
    def read_line(self):
647
454
        line = self._read_line()
648
 
        if not line.endswith(b'\n'):
 
455
        if not line.endswith('\n'):
649
456
            # end of file encountered reading from server
650
457
            raise errors.ConnectionReset(
651
 
                "Unexpected end of message. Please check connectivity "
652
 
                "and permissions, and report a bug if problems persist.")
 
458
                "please check connectivity and permissions",
 
459
                "(and try -Dhpss if further diagnosis is required)")
653
460
        return line
654
461
 
655
462
    def _read_line(self):
656
463
        """Helper for SmartClientMediumRequest.read_line.
657
 
 
 
464
        
658
465
        By default this forwards to self._medium._get_line because we are
659
466
        operating on the medium's stream.
660
467
        """
661
468
        return self._medium._get_line()
662
469
 
663
470
 
664
 
class _VfsRefuser(object):
665
 
    """An object that refuses all VFS requests.
666
 
 
667
 
    """
668
 
 
669
 
    def __init__(self):
670
 
        client._SmartClient.hooks.install_named_hook(
671
 
            'call', self.check_vfs, 'vfs refuser')
672
 
 
673
 
    def check_vfs(self, params):
674
 
        try:
675
 
            request_method = request.request_handlers.get(params.method)
676
 
        except KeyError:
677
 
            # A method we don't know about doesn't count as a VFS method.
678
 
            return
679
 
        if issubclass(request_method, vfs.VfsRequest):
680
 
            raise HpssVfsRequestNotAllowed(params.method, params.args)
681
 
 
682
 
 
683
 
class _DebugCounter(object):
684
 
    """An object that counts the HPSS calls made to each client medium.
685
 
 
686
 
    When a medium is garbage-collected, or failing that when
687
 
    breezy.global_state exits, the total number of calls made on that medium
688
 
    are reported via trace.note.
689
 
    """
690
 
 
691
 
    def __init__(self):
692
 
        self.counts = weakref.WeakKeyDictionary()
693
 
        client._SmartClient.hooks.install_named_hook(
694
 
            'call', self.increment_call_count, 'hpss call counter')
695
 
        breezy.get_global_state().cleanups.add_cleanup(self.flush_all)
696
 
 
697
 
    def track(self, medium):
698
 
        """Start tracking calls made to a medium.
699
 
 
700
 
        This only keeps a weakref to the medium, so shouldn't affect the
701
 
        medium's lifetime.
702
 
        """
703
 
        medium_repr = repr(medium)
704
 
        # Add this medium to the WeakKeyDictionary
705
 
        self.counts[medium] = dict(count=0, vfs_count=0,
706
 
                                   medium_repr=medium_repr)
707
 
        # Weakref callbacks are fired in reverse order of their association
708
 
        # with the referenced object.  So we add a weakref *after* adding to
709
 
        # the WeakKeyDict so that we can report the value from it before the
710
 
        # entry is removed by the WeakKeyDict's own callback.
711
 
        ref = weakref.ref(medium, self.done)
712
 
 
713
 
    def increment_call_count(self, params):
714
 
        # Increment the count in the WeakKeyDictionary
715
 
        value = self.counts[params.medium]
716
 
        value['count'] += 1
717
 
        try:
718
 
            request_method = request.request_handlers.get(params.method)
719
 
        except KeyError:
720
 
            # A method we don't know about doesn't count as a VFS method.
721
 
            return
722
 
        if issubclass(request_method, vfs.VfsRequest):
723
 
            value['vfs_count'] += 1
724
 
 
725
 
    def done(self, ref):
726
 
        value = self.counts[ref]
727
 
        count, vfs_count, medium_repr = (
728
 
            value['count'], value['vfs_count'], value['medium_repr'])
729
 
        # In case this callback is invoked for the same ref twice (by the
730
 
        # weakref callback and by the atexit function), set the call count back
731
 
        # to 0 so this item won't be reported twice.
732
 
        value['count'] = 0
733
 
        value['vfs_count'] = 0
734
 
        if count != 0:
735
 
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
736
 
                       count, vfs_count, medium_repr))
737
 
 
738
 
    def flush_all(self):
739
 
        for ref in list(self.counts.keys()):
740
 
            self.done(ref)
741
 
 
742
 
_debug_counter = None
743
 
_vfs_refuser = None
744
 
 
745
 
 
746
471
class SmartClientMedium(SmartMedium):
747
472
    """Smart client is a medium for sending smart protocol requests over."""
748
473
 
757
482
        # _remote_version_is_before tracks the bzr version the remote side
758
483
        # can be based on what we've seen so far.
759
484
        self._remote_version_is_before = None
760
 
        # Install debug hook function if debug flag is set.
761
 
        if 'hpss' in debug.debug_flags:
762
 
            global _debug_counter
763
 
            if _debug_counter is None:
764
 
                _debug_counter = _DebugCounter()
765
 
            _debug_counter.track(self)
766
 
        if 'hpss_client_no_vfs' in debug.debug_flags:
767
 
            global _vfs_refuser
768
 
            if _vfs_refuser is None:
769
 
                _vfs_refuser = _VfsRefuser()
770
485
 
771
486
    def _is_remote_before(self, version_tuple):
772
487
        """Is it possible the remote side supports RPCs for a given version?
797
512
        """
798
513
        if (self._remote_version_is_before is not None and
799
514
            version_tuple > self._remote_version_is_before):
800
 
            # We have been told that the remote side is older than some version
801
 
            # which is newer than a previously supplied older-than version.
802
 
            # This indicates that some smart verb call is not guarded
803
 
            # appropriately (it should simply not have been tried).
804
 
            trace.mutter(
 
515
            raise AssertionError(
805
516
                "_remember_remote_is_before(%r) called, but "
806
517
                "_remember_remote_is_before(%r) was called previously."
807
 
                , version_tuple, self._remote_version_is_before)
808
 
            if 'hpss' in debug.debug_flags:
809
 
                ui.ui_factory.show_warning(
810
 
                    "_remember_remote_is_before(%r) called, but "
811
 
                    "_remember_remote_is_before(%r) was called previously."
812
 
                    % (version_tuple, self._remote_version_is_before))
813
 
            return
 
518
                % (version_tuple, self._remote_version_is_before))
814
519
        self._remote_version_is_before = version_tuple
815
520
 
816
521
    def protocol_version(self):
825
530
                client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
826
531
                client_protocol.query_version()
827
532
                self._done_hello = True
828
 
            except errors.SmartProtocolError as e:
 
533
            except errors.SmartProtocolError, e:
829
534
                # Cache the error, just like we would cache a successful
830
535
                # result.
831
536
                self._protocol_version_error = e
850
555
 
851
556
    def disconnect(self):
852
557
        """If this medium maintains a persistent connection, close it.
853
 
 
 
558
        
854
559
        The default implementation does nothing.
855
560
        """
856
 
 
 
561
        
857
562
    def remote_path_from_transport(self, transport):
858
563
        """Convert transport into a path suitable for using in a request.
859
 
 
 
564
        
860
565
        Note that the resulting remote path doesn't encode the host name or
861
566
        anything but path, so it is only safe to use it in requests sent over
862
567
        the medium from the matching transport.
863
568
        """
864
569
        medium_base = urlutils.join(self.base, '/')
865
570
        rel_url = urlutils.relative_url(medium_base, transport.base)
866
 
        return urlutils.unquote(rel_url)
 
571
        return urllib.unquote(rel_url)
867
572
 
868
573
 
869
574
class SmartClientStreamMedium(SmartClientMedium):
890
595
 
891
596
    def _flush(self):
892
597
        """Flush the output stream.
893
 
 
 
598
        
894
599
        This method is used by the SmartClientStreamMediumRequest to ensure that
895
600
        all data for a request is sent, to avoid long timeouts or deadlocks.
896
601
        """
904
609
        """
905
610
        return SmartClientStreamMediumRequest(self)
906
611
 
907
 
    def reset(self):
908
 
        """We have been disconnected, reset current state.
909
 
 
910
 
        This resets things like _current_request and connected state.
911
 
        """
912
 
        self.disconnect()
913
 
        self._current_request = None
914
 
 
915
612
 
916
613
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
917
614
    """A client medium using simple pipes.
918
 
 
 
615
    
919
616
    This client does not manage the pipes: it assumes they will always be open.
920
617
    """
921
618
 
924
621
        self._readable_pipe = readable_pipe
925
622
        self._writeable_pipe = writeable_pipe
926
623
 
927
 
    def _accept_bytes(self, data):
 
624
    def _accept_bytes(self, bytes):
928
625
        """See SmartClientStreamMedium.accept_bytes."""
929
 
        try:
930
 
            self._writeable_pipe.write(data)
931
 
        except IOError as e:
932
 
            if e.errno in (errno.EINVAL, errno.EPIPE):
933
 
                raise errors.ConnectionReset(
934
 
                    "Error trying to write to subprocess", e)
935
 
            raise
936
 
        self._report_activity(len(data), 'write')
 
626
        self._writeable_pipe.write(bytes)
937
627
 
938
628
    def _flush(self):
939
629
        """See SmartClientStreamMedium._flush()."""
940
 
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
941
 
        #       However, testing shows that even when the child process is
942
 
        #       gone, this doesn't error.
943
630
        self._writeable_pipe.flush()
944
631
 
945
632
    def _read_bytes(self, count):
946
633
        """See SmartClientStreamMedium._read_bytes."""
947
 
        bytes_to_read = min(count, _MAX_READ_SIZE)
948
 
        data = self._readable_pipe.read(bytes_to_read)
949
 
        self._report_activity(len(data), 'read')
950
 
        return data
951
 
 
952
 
 
953
 
class SSHParams(object):
954
 
    """A set of parameters for starting a remote bzr via SSH."""
955
 
 
 
634
        return self._readable_pipe.read(count)
 
635
 
 
636
 
 
637
class SmartSSHClientMedium(SmartClientStreamMedium):
 
638
    """A client medium using SSH."""
 
639
    
956
640
    def __init__(self, host, port=None, username=None, password=None,
957
 
            bzr_remote_path='bzr'):
958
 
        self.host = host
959
 
        self.port = port
960
 
        self.username = username
961
 
        self.password = password
962
 
        self.bzr_remote_path = bzr_remote_path
963
 
 
964
 
 
965
 
class SmartSSHClientMedium(SmartClientStreamMedium):
966
 
    """A client medium using SSH.
967
 
 
968
 
    It delegates IO to a SmartSimplePipesClientMedium or
969
 
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
970
 
    """
971
 
 
972
 
    def __init__(self, base, ssh_params, vendor=None):
 
641
            base=None, vendor=None, bzr_remote_path=None):
973
642
        """Creates a client that will connect on the first use.
974
 
 
975
 
        :param ssh_params: A SSHParams instance.
 
643
        
976
644
        :param vendor: An optional override for the ssh vendor to use. See
977
 
            breezy.transport.ssh for details on ssh vendors.
 
645
            bzrlib.transport.ssh for details on ssh vendors.
978
646
        """
979
 
        self._real_medium = None
980
 
        self._ssh_params = ssh_params
981
 
        # for the benefit of progress making a short description of this
982
 
        # transport
983
 
        self._scheme = 'bzr+ssh'
984
 
        # SmartClientStreamMedium stores the repr of this object in its
985
 
        # _DebugCounter so we have to store all the values used in our repr
986
 
        # method before calling the super init.
987
647
        SmartClientStreamMedium.__init__(self, base)
 
648
        self._connected = False
 
649
        self._host = host
 
650
        self._password = password
 
651
        self._port = port
 
652
        self._username = username
 
653
        self._read_from = None
 
654
        self._ssh_connection = None
988
655
        self._vendor = vendor
989
 
        self._ssh_connection = None
990
 
 
991
 
    def __repr__(self):
992
 
        if self._ssh_params.port is None:
993
 
            maybe_port = ''
994
 
        else:
995
 
            maybe_port = ':%s' % self._ssh_params.port
996
 
        if self._ssh_params.username is None:
997
 
            maybe_user = ''
998
 
        else:
999
 
            maybe_user = '%s@' % self._ssh_params.username
1000
 
        return "%s(%s://%s%s%s/)" % (
1001
 
            self.__class__.__name__,
1002
 
            self._scheme,
1003
 
            maybe_user,
1004
 
            self._ssh_params.host,
1005
 
            maybe_port)
 
656
        self._write_to = None
 
657
        self._bzr_remote_path = bzr_remote_path
 
658
        if self._bzr_remote_path is None:
 
659
            symbol_versioning.warn(
 
660
                'bzr_remote_path is required as of bzr 0.92',
 
661
                DeprecationWarning, stacklevel=2)
 
662
            self._bzr_remote_path = os.environ.get('BZR_REMOTE_PATH', 'bzr')
1006
663
 
1007
664
    def _accept_bytes(self, bytes):
1008
665
        """See SmartClientStreamMedium.accept_bytes."""
1009
666
        self._ensure_connection()
1010
 
        self._real_medium.accept_bytes(bytes)
 
667
        self._write_to.write(bytes)
1011
668
 
1012
669
    def disconnect(self):
1013
670
        """See SmartClientMedium.disconnect()."""
1014
 
        if self._real_medium is not None:
1015
 
            self._real_medium.disconnect()
1016
 
            self._real_medium = None
1017
 
        if self._ssh_connection is not None:
1018
 
            self._ssh_connection.close()
1019
 
            self._ssh_connection = None
 
671
        if not self._connected:
 
672
            return
 
673
        self._read_from.close()
 
674
        self._write_to.close()
 
675
        self._ssh_connection.close()
 
676
        self._connected = False
1020
677
 
1021
678
    def _ensure_connection(self):
1022
679
        """Connect this medium if not already connected."""
1023
 
        if self._real_medium is not None:
 
680
        if self._connected:
1024
681
            return
1025
682
        if self._vendor is None:
1026
683
            vendor = ssh._get_ssh_vendor()
1027
684
        else:
1028
685
            vendor = self._vendor
1029
 
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
1030
 
                self._ssh_params.password, self._ssh_params.host,
1031
 
                self._ssh_params.port,
1032
 
                command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
 
686
        self._ssh_connection = vendor.connect_ssh(self._username,
 
687
                self._password, self._host, self._port,
 
688
                command=[self._bzr_remote_path, 'serve', '--inet',
1033
689
                         '--directory=/', '--allow-writes'])
1034
 
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
1035
 
        if io_kind == 'socket':
1036
 
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
1037
 
                self.base, io_object)
1038
 
        elif io_kind == 'pipes':
1039
 
            read_from, write_to = io_object
1040
 
            self._real_medium = SmartSimplePipesClientMedium(
1041
 
                read_from, write_to, self.base)
1042
 
        else:
1043
 
            raise AssertionError(
1044
 
                "Unexpected io_kind %r from %r"
1045
 
                % (io_kind, self._ssh_connection))
1046
 
        for hook in transport.Transport.hooks["post_connect"]:
1047
 
            hook(self)
 
690
        self._read_from, self._write_to = \
 
691
            self._ssh_connection.get_filelike_channels()
 
692
        self._connected = True
1048
693
 
1049
694
    def _flush(self):
1050
695
        """See SmartClientStreamMedium._flush()."""
1051
 
        self._real_medium._flush()
 
696
        self._write_to.flush()
1052
697
 
1053
698
    def _read_bytes(self, count):
1054
699
        """See SmartClientStreamMedium.read_bytes."""
1055
 
        if self._real_medium is None:
 
700
        if not self._connected:
1056
701
            raise errors.MediumNotConnected(self)
1057
 
        return self._real_medium.read_bytes(count)
 
702
        bytes_to_read = min(count, _MAX_READ_SIZE)
 
703
        return self._read_from.read(bytes_to_read)
1058
704
 
1059
705
 
1060
706
# Port 4155 is the default port for bzr://, registered with IANA.
1061
 
BZR_DEFAULT_INTERFACE = None
 
707
BZR_DEFAULT_INTERFACE = '0.0.0.0'
1062
708
BZR_DEFAULT_PORT = 4155
1063
709
 
1064
710
 
1065
 
class SmartClientSocketMedium(SmartClientStreamMedium):
1066
 
    """A client medium using a socket.
1067
 
 
1068
 
    This class isn't usable directly.  Use one of its subclasses instead.
1069
 
    """
1070
 
 
1071
 
    def __init__(self, base):
 
711
class SmartTCPClientMedium(SmartClientStreamMedium):
 
712
    """A client medium using TCP."""
 
713
    
 
714
    def __init__(self, host, port, base):
 
715
        """Creates a client that will connect on the first use."""
1072
716
        SmartClientStreamMedium.__init__(self, base)
 
717
        self._connected = False
 
718
        self._host = host
 
719
        self._port = port
1073
720
        self._socket = None
1074
 
        self._connected = False
1075
721
 
1076
722
    def _accept_bytes(self, bytes):
1077
723
        """See SmartClientMedium.accept_bytes."""
1078
724
        self._ensure_connection()
1079
 
        osutils.send_all(self._socket, bytes, self._report_activity)
1080
 
 
1081
 
    def _ensure_connection(self):
1082
 
        """Connect this medium if not already connected."""
1083
 
        raise NotImplementedError(self._ensure_connection)
1084
 
 
1085
 
    def _flush(self):
1086
 
        """See SmartClientStreamMedium._flush().
1087
 
 
1088
 
        For sockets we do no flushing. For TCP sockets we may want to turn off
1089
 
        TCP_NODELAY and add a means to do a flush, but that can be done in the
1090
 
        future.
1091
 
        """
1092
 
 
1093
 
    def _read_bytes(self, count):
1094
 
        """See SmartClientMedium.read_bytes."""
1095
 
        if not self._connected:
1096
 
            raise errors.MediumNotConnected(self)
1097
 
        return osutils.read_bytes_from_socket(
1098
 
            self._socket, self._report_activity)
 
725
        osutils.send_all(self._socket, bytes)
1099
726
 
1100
727
    def disconnect(self):
1101
728
        """See SmartClientMedium.disconnect()."""
1105
732
        self._socket = None
1106
733
        self._connected = False
1107
734
 
1108
 
 
1109
 
class SmartTCPClientMedium(SmartClientSocketMedium):
1110
 
    """A client medium that creates a TCP connection."""
1111
 
 
1112
 
    def __init__(self, host, port, base):
1113
 
        """Creates a client that will connect on the first use."""
1114
 
        SmartClientSocketMedium.__init__(self, base)
1115
 
        self._host = host
1116
 
        self._port = port
1117
 
 
1118
735
    def _ensure_connection(self):
1119
736
        """Connect this medium if not already connected."""
1120
737
        if self._connected:
1121
738
            return
 
739
        self._socket = socket.socket()
 
740
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
1122
741
        if self._port is None:
1123
742
            port = BZR_DEFAULT_PORT
1124
743
        else:
1125
744
            port = int(self._port)
1126
745
        try:
1127
 
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
1128
 
                socket.SOCK_STREAM, 0, 0)
1129
 
        except socket.gaierror as xxx_todo_changeme:
1130
 
            (err_num, err_msg) = xxx_todo_changeme.args
1131
 
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
1132
 
                    (self._host, port, err_msg))
1133
 
        # Initialize err in case there are no addresses returned:
1134
 
        last_err = socket.error("no address found for %s" % self._host)
1135
 
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
1136
 
            try:
1137
 
                self._socket = socket.socket(family, socktype, proto)
1138
 
                self._socket.setsockopt(socket.IPPROTO_TCP,
1139
 
                                        socket.TCP_NODELAY, 1)
1140
 
                self._socket.connect(sockaddr)
1141
 
            except socket.error as err:
1142
 
                if self._socket is not None:
1143
 
                    self._socket.close()
1144
 
                self._socket = None
1145
 
                last_err = err
1146
 
                continue
1147
 
            break
1148
 
        if self._socket is None:
 
746
            self._socket.connect((self._host, port))
 
747
        except socket.error, err:
1149
748
            # socket errors either have a (string) or (errno, string) as their
1150
749
            # args.
1151
 
            if isinstance(last_err.args, str):
1152
 
                err_msg = last_err.args
 
750
            if type(err.args) is str:
 
751
                err_msg = err.args
1153
752
            else:
1154
 
                err_msg = last_err.args[1]
 
753
                err_msg = err.args[1]
1155
754
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
1156
755
                    (self._host, port, err_msg))
1157
756
        self._connected = True
1158
 
        for hook in transport.Transport.hooks["post_connect"]:
1159
 
            hook(self)
1160
 
 
1161
 
 
1162
 
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
1163
 
    """A client medium for an already connected socket.
1164
 
    
1165
 
    Note that this class will assume it "owns" the socket, so it will close it
1166
 
    when its disconnect method is called.
1167
 
    """
1168
 
 
1169
 
    def __init__(self, base, sock):
1170
 
        SmartClientSocketMedium.__init__(self, base)
1171
 
        self._socket = sock
1172
 
        self._connected = True
1173
 
 
1174
 
    def _ensure_connection(self):
1175
 
        # Already connected, by definition!  So nothing to do.
1176
 
        pass
 
757
 
 
758
    def _flush(self):
 
759
        """See SmartClientStreamMedium._flush().
 
760
        
 
761
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and 
 
762
        add a means to do a flush, but that can be done in the future.
 
763
        """
 
764
 
 
765
    def _read_bytes(self, count):
 
766
        """See SmartClientMedium.read_bytes."""
 
767
        if not self._connected:
 
768
            raise errors.MediumNotConnected(self)
 
769
        # We ignore the desired_count because on sockets it's more efficient to
 
770
        # read large chunks (of _MAX_READ_SIZE bytes) at a time.
 
771
        return self._socket.recv(_MAX_READ_SIZE)
1177
772
 
1178
773
 
1179
774
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
1192
787
 
1193
788
    def _accept_bytes(self, bytes):
1194
789
        """See SmartClientMediumRequest._accept_bytes.
1195
 
 
 
790
        
1196
791
        This forwards to self._medium._accept_bytes because we are operating
1197
792
        on the mediums stream.
1198
793
        """
1201
796
    def _finished_reading(self):
1202
797
        """See SmartClientMediumRequest._finished_reading.
1203
798
 
1204
 
        This clears the _current_request on self._medium to allow a new
 
799
        This clears the _current_request on self._medium to allow a new 
1205
800
        request to be created.
1206
801
        """
1207
802
        if self._medium._current_request is not self:
1208
803
            raise AssertionError()
1209
804
        self._medium._current_request = None
1210
 
 
 
805
        
1211
806
    def _finished_writing(self):
1212
807
        """See SmartClientMediumRequest._finished_writing.
1213
808
 
1214
809
        This invokes self._medium._flush to ensure all bytes are transmitted.
1215
810
        """
1216
811
        self._medium._flush()
 
812