1
# Copyright (C) 2006 Canonical Ltd
 
 
3
# This program is free software; you can redistribute it and/or modify
 
 
4
# it under the terms of the GNU General Public License as published by
 
 
5
# the Free Software Foundation; either version 2 of the License, or
 
 
6
# (at your option) any later version.
 
 
8
# This program is distributed in the hope that it will be useful,
 
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
 
11
# GNU General Public License for more details.
 
 
13
# You should have received a copy of the GNU General Public License
 
 
14
# along with this program; if not, write to the Free Software
 
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
 
17
"""The 'medium' layer for the smart servers and clients.
 
 
19
"Medium" here is the noun meaning "a means of transmission", not the adjective
 
 
20
for "the quality between big and small."
 
 
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
 
 
23
over SSH), and pass them to and from the protocol logic.  See the overview in
 
 
24
bzrlib/transport/smart/__init__.py.
 
 
30
from bzrlib import errors
 
 
31
from bzrlib.smart import protocol
 
 
33
    from bzrlib.transport import ssh
 
 
34
except errors.ParamikoNotPresent:
 
 
35
    # no paramiko.  SmartSSHClientMedium will break.
 
 
39
class SmartServerStreamMedium(object):
 
 
40
    """Handles smart commands coming over a stream.
 
 
42
    The stream may be a pipe connected to sshd, or a tcp socket, or an
 
 
43
    in-process fifo for testing.
 
 
45
    One instance is created for each connected client; it can serve multiple
 
 
46
    requests in the lifetime of the connection.
 
 
48
    The server passes requests through to an underlying backing transport, 
 
 
49
    which will typically be a LocalTransport looking at the server's filesystem.
 
 
52
    def __init__(self, backing_transport):
 
 
53
        """Construct new server.
 
 
55
        :param backing_transport: Transport for the directory served.
 
 
57
        # backing_transport could be passed to serve instead of __init__
 
 
58
        self.backing_transport = backing_transport
 
 
62
        """Serve requests until the client disconnects."""
 
 
63
        # Keep a reference to stderr because the sys module's globals get set to
 
 
64
        # None during interpreter shutdown.
 
 
65
        from sys import stderr
 
 
67
            while not self.finished:
 
 
68
                server_protocol = protocol.SmartServerRequestProtocolOne(
 
 
69
                    self.backing_transport, self._write_out)
 
 
70
                self._serve_one_request(server_protocol)
 
 
72
            stderr.write("%s terminating on exception %s\n" % (self, e))
 
 
75
    def _serve_one_request(self, protocol):
 
 
76
        """Read one request from input, process, send back a response.
 
 
78
        :param protocol: a SmartServerRequestProtocol.
 
 
81
            self._serve_one_request_unguarded(protocol)
 
 
82
        except KeyboardInterrupt:
 
 
85
            self.terminate_due_to_error()
 
 
87
    def terminate_due_to_error(self):
 
 
88
        """Called when an unhandled exception from the protocol occurs."""
 
 
89
        raise NotImplementedError(self.terminate_due_to_error)
 
 
92
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
 
 
94
    def __init__(self, sock, backing_transport):
 
 
97
        :param sock: the socket the server will read from.  It will be put
 
 
100
        SmartServerStreamMedium.__init__(self, backing_transport)
 
 
102
        sock.setblocking(True)
 
 
105
    def _serve_one_request_unguarded(self, protocol):
 
 
106
        while protocol.next_read_size():
 
 
108
                protocol.accept_bytes(self.push_back)
 
 
111
                bytes = self.socket.recv(4096)
 
 
115
                protocol.accept_bytes(bytes)
 
 
117
        self.push_back = protocol.excess_buffer
 
 
119
    def terminate_due_to_error(self):
 
 
120
        """Called when an unhandled exception from the protocol occurs."""
 
 
121
        # TODO: This should log to a server log file, but no such thing
 
 
122
        # exists yet.  Andrew Bennetts 2006-09-29.
 
 
126
    def _write_out(self, bytes):
 
 
127
        self.socket.sendall(bytes)
 
 
130
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
 
 
132
    def __init__(self, in_file, out_file, backing_transport):
 
 
133
        """Construct new server.
 
 
135
        :param in_file: Python file from which requests can be read.
 
 
136
        :param out_file: Python file to write responses.
 
 
137
        :param backing_transport: Transport for the directory served.
 
 
139
        SmartServerStreamMedium.__init__(self, backing_transport)
 
 
143
    def _serve_one_request_unguarded(self, protocol):
 
 
145
            bytes_to_read = protocol.next_read_size()
 
 
146
            if bytes_to_read == 0:
 
 
147
                # Finished serving this request.
 
 
150
            bytes = self._in.read(bytes_to_read)
 
 
152
                # Connection has been closed.
 
 
156
            protocol.accept_bytes(bytes)
 
 
158
    def terminate_due_to_error(self):
 
 
159
        # TODO: This should log to a server log file, but no such thing
 
 
160
        # exists yet.  Andrew Bennetts 2006-09-29.
 
 
164
    def _write_out(self, bytes):
 
 
165
        self._out.write(bytes)
 
 
168
class SmartClientMediumRequest(object):
 
 
169
    """A request on a SmartClientMedium.
 
 
171
    Each request allows bytes to be provided to it via accept_bytes, and then
 
 
172
    the response bytes to be read via read_bytes.
 
 
175
    request.accept_bytes('123')
 
 
176
    request.finished_writing()
 
 
177
    result = request.read_bytes(3)
 
 
178
    request.finished_reading()
 
 
180
    It is up to the individual SmartClientMedium whether multiple concurrent
 
 
181
    requests can exist. See SmartClientMedium.get_request to obtain instances 
 
 
182
    of SmartClientMediumRequest, and the concrete Medium you are using for 
 
 
183
    details on concurrency and pipelining.
 
 
186
    def __init__(self, medium):
 
 
187
        """Construct a SmartClientMediumRequest for the medium medium."""
 
 
188
        self._medium = medium
 
 
189
        # we track state by constants - we may want to use the same
 
 
190
        # pattern as BodyReader if it gets more complex.
 
 
191
        # valid states are: "writing", "reading", "done"
 
 
192
        self._state = "writing"
 
 
194
    def accept_bytes(self, bytes):
 
 
195
        """Accept bytes for inclusion in this request.
 
 
197
        This method may not be be called after finished_writing() has been
 
 
198
        called.  It depends upon the Medium whether or not the bytes will be
 
 
199
        immediately transmitted. Message based Mediums will tend to buffer the
 
 
200
        bytes until finished_writing() is called.
 
 
202
        :param bytes: A bytestring.
 
 
204
        if self._state != "writing":
 
 
205
            raise errors.WritingCompleted(self)
 
 
206
        self._accept_bytes(bytes)
 
 
208
    def _accept_bytes(self, bytes):
 
 
209
        """Helper for accept_bytes.
 
 
211
        Accept_bytes checks the state of the request to determing if bytes
 
 
212
        should be accepted. After that it hands off to _accept_bytes to do the
 
 
215
        raise NotImplementedError(self._accept_bytes)
 
 
217
    def finished_reading(self):
 
 
218
        """Inform the request that all desired data has been read.
 
 
220
        This will remove the request from the pipeline for its medium (if the
 
 
221
        medium supports pipelining) and any further calls to methods on the
 
 
222
        request will raise ReadingCompleted.
 
 
224
        if self._state == "writing":
 
 
225
            raise errors.WritingNotComplete(self)
 
 
226
        if self._state != "reading":
 
 
227
            raise errors.ReadingCompleted(self)
 
 
229
        self._finished_reading()
 
 
231
    def _finished_reading(self):
 
 
232
        """Helper for finished_reading.
 
 
234
        finished_reading checks the state of the request to determine if 
 
 
235
        finished_reading is allowed, and if it is hands off to _finished_reading
 
 
236
        to perform the action.
 
 
238
        raise NotImplementedError(self._finished_reading)
 
 
240
    def finished_writing(self):
 
 
241
        """Finish the writing phase of this request.
 
 
243
        This will flush all pending data for this request along the medium.
 
 
244
        After calling finished_writing, you may not call accept_bytes anymore.
 
 
246
        if self._state != "writing":
 
 
247
            raise errors.WritingCompleted(self)
 
 
248
        self._state = "reading"
 
 
249
        self._finished_writing()
 
 
251
    def _finished_writing(self):
 
 
252
        """Helper for finished_writing.
 
 
254
        finished_writing checks the state of the request to determine if 
 
 
255
        finished_writing is allowed, and if it is hands off to _finished_writing
 
 
256
        to perform the action.
 
 
258
        raise NotImplementedError(self._finished_writing)
 
 
260
    def read_bytes(self, count):
 
 
261
        """Read bytes from this requests response.
 
 
263
        This method will block and wait for count bytes to be read. It may not
 
 
264
        be invoked until finished_writing() has been called - this is to ensure
 
 
265
        a message-based approach to requests, for compatability with message
 
 
266
        based mediums like HTTP.
 
 
268
        if self._state == "writing":
 
 
269
            raise errors.WritingNotComplete(self)
 
 
270
        if self._state != "reading":
 
 
271
            raise errors.ReadingCompleted(self)
 
 
272
        return self._read_bytes(count)
 
 
274
    def _read_bytes(self, count):
 
 
275
        """Helper for read_bytes.
 
 
277
        read_bytes checks the state of the request to determing if bytes
 
 
278
        should be read. After that it hands off to _read_bytes to do the
 
 
281
        raise NotImplementedError(self._read_bytes)
 
 
284
class SmartClientMedium(object):
 
 
285
    """Smart client is a medium for sending smart protocol requests over."""
 
 
287
    def disconnect(self):
 
 
288
        """If this medium maintains a persistent connection, close it.
 
 
290
        The default implementation does nothing.
 
 
294
class SmartClientStreamMedium(SmartClientMedium):
 
 
295
    """Stream based medium common class.
 
 
297
    SmartClientStreamMediums operate on a stream. All subclasses use a common
 
 
298
    SmartClientStreamMediumRequest for their requests, and should implement
 
 
299
    _accept_bytes and _read_bytes to allow the request objects to send and
 
 
304
        self._current_request = None
 
 
306
    def accept_bytes(self, bytes):
 
 
307
        self._accept_bytes(bytes)
 
 
310
        """The SmartClientStreamMedium knows how to close the stream when it is
 
 
316
        """Flush the output stream.
 
 
318
        This method is used by the SmartClientStreamMediumRequest to ensure that
 
 
319
        all data for a request is sent, to avoid long timeouts or deadlocks.
 
 
321
        raise NotImplementedError(self._flush)
 
 
323
    def get_request(self):
 
 
324
        """See SmartClientMedium.get_request().
 
 
326
        SmartClientStreamMedium always returns a SmartClientStreamMediumRequest
 
 
329
        return SmartClientStreamMediumRequest(self)
 
 
331
    def read_bytes(self, count):
 
 
332
        return self._read_bytes(count)
 
 
335
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
 
 
336
    """A client medium using simple pipes.
 
 
338
    This client does not manage the pipes: it assumes they will always be open.
 
 
341
    def __init__(self, readable_pipe, writeable_pipe):
 
 
342
        SmartClientStreamMedium.__init__(self)
 
 
343
        self._readable_pipe = readable_pipe
 
 
344
        self._writeable_pipe = writeable_pipe
 
 
346
    def _accept_bytes(self, bytes):
 
 
347
        """See SmartClientStreamMedium.accept_bytes."""
 
 
348
        self._writeable_pipe.write(bytes)
 
 
351
        """See SmartClientStreamMedium._flush()."""
 
 
352
        self._writeable_pipe.flush()
 
 
354
    def _read_bytes(self, count):
 
 
355
        """See SmartClientStreamMedium._read_bytes."""
 
 
356
        return self._readable_pipe.read(count)
 
 
359
class SmartSSHClientMedium(SmartClientStreamMedium):
 
 
360
    """A client medium using SSH."""
 
 
362
    def __init__(self, host, port=None, username=None, password=None,
 
 
364
        """Creates a client that will connect on the first use.
 
 
366
        :param vendor: An optional override for the ssh vendor to use. See
 
 
367
            bzrlib.transport.ssh for details on ssh vendors.
 
 
369
        SmartClientStreamMedium.__init__(self)
 
 
370
        self._connected = False
 
 
372
        self._password = password
 
 
374
        self._username = username
 
 
375
        self._read_from = None
 
 
376
        self._ssh_connection = None
 
 
377
        self._vendor = vendor
 
 
378
        self._write_to = None
 
 
380
    def _accept_bytes(self, bytes):
 
 
381
        """See SmartClientStreamMedium.accept_bytes."""
 
 
382
        self._ensure_connection()
 
 
383
        self._write_to.write(bytes)
 
 
385
    def disconnect(self):
 
 
386
        """See SmartClientMedium.disconnect()."""
 
 
387
        if not self._connected:
 
 
389
        self._read_from.close()
 
 
390
        self._write_to.close()
 
 
391
        self._ssh_connection.close()
 
 
392
        self._connected = False
 
 
394
    def _ensure_connection(self):
 
 
395
        """Connect this medium if not already connected."""
 
 
398
        executable = os.environ.get('BZR_REMOTE_PATH', 'bzr')
 
 
399
        if self._vendor is None:
 
 
400
            vendor = ssh._get_ssh_vendor()
 
 
402
            vendor = self._vendor
 
 
403
        self._ssh_connection = vendor.connect_ssh(self._username,
 
 
404
                self._password, self._host, self._port,
 
 
405
                command=[executable, 'serve', '--inet', '--directory=/',
 
 
407
        self._read_from, self._write_to = \
 
 
408
            self._ssh_connection.get_filelike_channels()
 
 
409
        self._connected = True
 
 
412
        """See SmartClientStreamMedium._flush()."""
 
 
413
        self._write_to.flush()
 
 
415
    def _read_bytes(self, count):
 
 
416
        """See SmartClientStreamMedium.read_bytes."""
 
 
417
        if not self._connected:
 
 
418
            raise errors.MediumNotConnected(self)
 
 
419
        return self._read_from.read(count)
 
 
422
class SmartTCPClientMedium(SmartClientStreamMedium):
 
 
423
    """A client medium using TCP."""
 
 
425
    def __init__(self, host, port):
 
 
426
        """Creates a client that will connect on the first use."""
 
 
427
        SmartClientStreamMedium.__init__(self)
 
 
428
        self._connected = False
 
 
433
    def _accept_bytes(self, bytes):
 
 
434
        """See SmartClientMedium.accept_bytes."""
 
 
435
        self._ensure_connection()
 
 
436
        self._socket.sendall(bytes)
 
 
438
    def disconnect(self):
 
 
439
        """See SmartClientMedium.disconnect()."""
 
 
440
        if not self._connected:
 
 
444
        self._connected = False
 
 
446
    def _ensure_connection(self):
 
 
447
        """Connect this medium if not already connected."""
 
 
450
        self._socket = socket.socket()
 
 
451
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
 
452
        result = self._socket.connect_ex((self._host, int(self._port)))
 
 
454
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
 
 
455
                    (self._host, self._port, os.strerror(result)))
 
 
456
        self._connected = True
 
 
459
        """See SmartClientStreamMedium._flush().
 
 
461
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and 
 
 
462
        add a means to do a flush, but that can be done in the future.
 
 
465
    def _read_bytes(self, count):
 
 
466
        """See SmartClientMedium.read_bytes."""
 
 
467
        if not self._connected:
 
 
468
            raise errors.MediumNotConnected(self)
 
 
469
        return self._socket.recv(count)
 
 
472
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
 
 
473
    """A SmartClientMediumRequest that works with an SmartClientStreamMedium."""
 
 
475
    def __init__(self, medium):
 
 
476
        SmartClientMediumRequest.__init__(self, medium)
 
 
477
        # check that we are safe concurrency wise. If some streams start
 
 
478
        # allowing concurrent requests - i.e. via multiplexing - then this
 
 
479
        # assert should be moved to SmartClientStreamMedium.get_request,
 
 
480
        # and the setting/unsetting of _current_request likewise moved into
 
 
481
        # that class : but its unneeded overhead for now. RBC 20060922
 
 
482
        if self._medium._current_request is not None:
 
 
483
            raise errors.TooManyConcurrentRequests(self._medium)
 
 
484
        self._medium._current_request = self
 
 
486
    def _accept_bytes(self, bytes):
 
 
487
        """See SmartClientMediumRequest._accept_bytes.
 
 
489
        This forwards to self._medium._accept_bytes because we are operating
 
 
490
        on the mediums stream.
 
 
492
        self._medium._accept_bytes(bytes)
 
 
494
    def _finished_reading(self):
 
 
495
        """See SmartClientMediumRequest._finished_reading.
 
 
497
        This clears the _current_request on self._medium to allow a new 
 
 
498
        request to be created.
 
 
500
        assert self._medium._current_request is self
 
 
501
        self._medium._current_request = None
 
 
503
    def _finished_writing(self):
 
 
504
        """See SmartClientMediumRequest._finished_writing.
 
 
506
        This invokes self._medium._flush to ensure all bytes are transmitted.
 
 
508
        self._medium._flush()
 
 
510
    def _read_bytes(self, count):
 
 
511
        """See SmartClientMediumRequest._read_bytes.
 
 
513
        This forwards to self._medium._read_bytes because we are operating
 
 
514
        on the mediums stream.
 
 
516
        return self._medium._read_bytes(count)