/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/transport/smart.py

Seperate SmartServer{Pipe,Socket}StreamMedium out of SmartServerStreamMedium.  Use recv to make the socket server medium better.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
"""Smart-server protocol, client and server.
 
18
 
 
19
Requests are sent as a command and list of arguments, followed by optional
 
20
bulk body data.  Responses are similarly a response and list of arguments,
 
21
followed by bulk body data. ::
 
22
 
 
23
  SEP := '\001'
 
24
    Fields are separated by Ctrl-A.
 
25
  BULK_DATA := CHUNK+ TRAILER
 
26
    Chunks can be repeated as many times as necessary.
 
27
  CHUNK := CHUNK_LEN CHUNK_BODY
 
28
  CHUNK_LEN := DIGIT+ NEWLINE
 
29
    Gives the number of bytes in the following chunk.
 
30
  CHUNK_BODY := BYTE[chunk_len]
 
31
  TRAILER := SUCCESS_TRAILER | ERROR_TRAILER
 
32
  SUCCESS_TRAILER := 'done' NEWLINE
 
33
  ERROR_TRAILER := 
 
34
 
 
35
Paths are passed across the network.  The client needs to see a namespace that
 
36
includes any repository that might need to be referenced, and the client needs
 
37
to know about a root directory beyond which it cannot ascend.
 
38
 
 
39
Servers run over ssh will typically want to be able to access any path the user 
 
40
can access.  Public servers on the other hand (which might be over http, ssh
 
41
or tcp) will typically want to restrict access to only a particular directory 
 
42
and its children, so will want to do a software virtual root at that level.
 
43
In other words they'll want to rewrite incoming paths to be under that level
 
44
(and prevent escaping using ../ tricks.)
 
45
 
 
46
URLs that include ~ should probably be passed across to the server verbatim
 
47
and the server can expand them.  This will proably not be meaningful when 
 
48
limited to a directory?
 
49
 
 
50
At the bottom level socket, pipes, HTTP server.  For sockets, we have the
 
51
idea that you have multiple requests and get have a read error because the
 
52
other side did shutdown sd send.  For pipes we have read pipe which will have a
 
53
zero read which marks end-of-file.  For HTTP server environment there is not
 
54
end-of-stream because each request coming into the server is independent.
 
55
 
 
56
So we need a wrapper around pipes and sockets to seperate out reqeusts from
 
57
substrate and this will give us a single model which is consist for HTTP,
 
58
sockets and pipes.
 
59
 
 
60
Server-side
 
61
-----------
 
62
 
 
63
 MEDIUM  (factory for protocol, reads bytes & pushes to protocol,
 
64
          uses protocol to detect end-of-request, sends written
 
65
          bytes to client) e.g. socket, pipe, HTTP request handler.
 
66
  ^
 
67
  | bytes.
 
68
  v
 
69
 
 
70
PROTOCOL  (serialisation, deserialisation)  accepts bytes for one
 
71
          request, decodes according to internal state, pushes
 
72
          structured data to handler.  accepts structured data from
 
73
          handler and encodes and writes to the medium.  factory for
 
74
          handler.
 
75
  ^
 
76
  | structured data
 
77
  v
 
78
 
 
79
HANDLER   (domain logic) accepts structured data, operates state
 
80
          machine until the request can be satisfied,
 
81
          sends structured data to the protocol.
 
82
 
 
83
 
 
84
Client-side
 
85
-----------
 
86
 
 
87
 CLIENT             domain logic, accepts domain requests, generated structured
 
88
                    data, reads structured data from responses and turns into
 
89
                    domain data.  Sends structured data to the protocol.
 
90
                    Operates state machines until the request can be delivered
 
91
                    (e.g. reading from a bundle generated in bzrlib to deliver a
 
92
                    complete request).
 
93
 
 
94
                    Possibly this should just be RemoteBzrDir, RemoteTransport,
 
95
                    ...
 
96
  ^
 
97
  | structured data
 
98
  v
 
99
 
 
100
PROTOCOL  (serialisation, deserialisation)  accepts structured data for one
 
101
          request, encodes and writes to the medium.  Reads bytes from the
 
102
          medium, decodes and allows the client to read structured data.
 
103
  ^
 
104
  | bytes.
 
105
  v
 
106
 
 
107
 MEDIUM  (accepts bytes from the protocol & delivers to the remote server.
 
108
          Allows the potocol to read bytes e.g. socket, pipe, HTTP request.
 
109
"""
 
110
 
 
111
 
 
112
# TODO: _translate_error should be on the client, not the transport because
 
113
#     error coding is wire protocol specific.
 
114
 
 
115
# TODO: A plain integer from query_version is too simple; should give some
 
116
# capabilities too?
 
117
 
 
118
# TODO: Server should probably catch exceptions within itself and send them
 
119
# back across the network.  (But shouldn't catch KeyboardInterrupt etc)
 
120
# Also needs to somehow report protocol errors like bad requests.  Need to
 
121
# consider how we'll handle error reporting, e.g. if we get halfway through a
 
122
# bulk transfer and then something goes wrong.
 
123
 
 
124
# TODO: Standard marker at start of request/response lines?
 
125
 
 
126
# TODO: Make each request and response self-validatable, e.g. with checksums.
 
127
#
 
128
# TODO: get/put objects could be changed to gradually read back the data as it
 
129
# comes across the network
 
130
#
 
131
# TODO: What should the server do if it hits an error and has to terminate?
 
132
#
 
133
# TODO: is it useful to allow multiple chunks in the bulk data?
 
134
#
 
135
# TODO: If we get an exception during transmission of bulk data we can't just
 
136
# emit the exception because it won't be seen.
 
137
#   John proposes:  I think it would be worthwhile to have a header on each
 
138
#   chunk, that indicates it is another chunk. Then you can send an 'error'
 
139
#   chunk as long as you finish the previous chunk.
 
140
#
 
141
# TODO: Clone method on Transport; should work up towards parent directory;
 
142
# unclear how this should be stored or communicated to the server... maybe
 
143
# just pass it on all relevant requests?
 
144
#
 
145
# TODO: Better name than clone() for changing between directories.  How about
 
146
# open_dir or change_dir or chdir?
 
147
#
 
148
# TODO: Is it really good to have the notion of current directory within the
 
149
# connection?  Perhaps all Transports should factor out a common connection
 
150
# from the thing that has the directory context?
 
151
#
 
152
# TODO: Pull more things common to sftp and ssh to a higher level.
 
153
#
 
154
# TODO: The server that manages a connection should be quite small and retain
 
155
# minimum state because each of the requests are supposed to be stateless.
 
156
# Then we can write another implementation that maps to http.
 
157
#
 
158
# TODO: What to do when a client connection is garbage collected?  Maybe just
 
159
# abruptly drop the connection?
 
160
#
 
161
# TODO: Server in some cases will need to restrict access to files outside of
 
162
# a particular root directory.  LocalTransport doesn't do anything to stop you
 
163
# ascending above the base directory, so we need to prevent paths
 
164
# containing '..' in either the server or transport layers.  (Also need to
 
165
# consider what happens if someone creates a symlink pointing outside the 
 
166
# directory tree...)
 
167
#
 
168
# TODO: Server should rebase absolute paths coming across the network to put
 
169
# them under the virtual root, if one is in use.  LocalTransport currently
 
170
# doesn't do that; if you give it an absolute path it just uses it.
 
171
 
172
# XXX: Arguments can't contain newlines or ascii; possibly we should e.g.
 
173
# urlescape them instead.  Indeed possibly this should just literally be
 
174
# http-over-ssh.
 
175
#
 
176
# FIXME: This transport, with several others, has imperfect handling of paths
 
177
# within urls.  It'd probably be better for ".." from a root to raise an error
 
178
# rather than return the same directory as we do at present.
 
179
#
 
180
# TODO: Rather than working at the Transport layer we want a Branch,
 
181
# Repository or BzrDir objects that talk to a server.
 
182
#
 
183
# TODO: Probably want some way for server commands to gradually produce body
 
184
# data rather than passing it as a string; they could perhaps pass an
 
185
# iterator-like callback that will gradually yield data; it probably needs a
 
186
# close() method that will always be closed to do any necessary cleanup.
 
187
#
 
188
# TODO: Split the actual smart server from the ssh encoding of it.
 
189
#
 
190
# TODO: Perhaps support file-level readwrite operations over the transport
 
191
# too.
 
192
#
 
193
# TODO: SmartBzrDir class, proxying all Branch etc methods across to another
 
194
# branch doing file-level operations.
 
195
#
 
196
# TODO: jam 20060915 _decode_tuple is acting directly on input over
 
197
#       the socket, and it assumes everything is UTF8 sections separated
 
198
#       by \001. Which means a request like '\002' Will abort the connection
 
199
#       because of a UnicodeDecodeError. It does look like invalid data will
 
200
#       kill the SmartServerStreamMedium, but only with an abort + exception, and 
 
201
#       the overall server shouldn't die.
 
202
 
 
203
from cStringIO import StringIO
 
204
import os
 
205
import select
 
206
import socket
 
207
import tempfile
 
208
import threading
 
209
import urllib
 
210
import urlparse
 
211
 
 
212
from bzrlib import (
 
213
    bzrdir,
 
214
    errors,
 
215
    revision,
 
216
    transport,
 
217
    trace,
 
218
    urlutils,
 
219
    )
 
220
from bzrlib.bundle.serializer import write_bundle
 
221
 
 
222
# must do this otherwise urllib can't parse the urls properly :(
 
223
for scheme in ['ssh', 'bzr', 'bzr+loopback', 'bzr+ssh']:
 
224
    transport.register_urlparse_netloc_protocol(scheme)
 
225
del scheme
 
226
 
 
227
 
 
228
def _recv_tuple(from_file):
 
229
    req_line = from_file.readline()
 
230
    return _decode_tuple(req_line)
 
231
 
 
232
 
 
233
def _decode_tuple(req_line):
 
234
    if req_line == None or req_line == '':
 
235
        return None
 
236
    if req_line[-1] != '\n':
 
237
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
 
238
    return tuple((a.decode('utf-8') for a in req_line[:-1].split('\x01')))
 
239
 
 
240
 
 
241
def _encode_tuple(args):
 
242
    """Encode the tuple args to a bytestream."""
 
243
    return '\x01'.join((a.encode('utf-8') for a in args)) + '\n'
 
244
 
 
245
 
 
246
class SmartProtocolBase(object):
 
247
    """Methods common to client and server"""
 
248
 
 
249
    # TODO: this only actually accomodates a single block; possibly should support
 
250
    # multiple chunks?
 
251
    def _recv_bulk(self):
 
252
        # This is OBSOLETE except for the double handline in the server: 
 
253
        # the read_bulk + reencode noise.
 
254
        chunk_len = self._in.readline()
 
255
        try:
 
256
            chunk_len = int(chunk_len)
 
257
        except ValueError:
 
258
            raise errors.SmartProtocolError("bad chunk length line %r" % chunk_len)
 
259
        bulk = self._in.read(chunk_len)
 
260
        if len(bulk) != chunk_len:
 
261
            raise errors.SmartProtocolError("short read fetching bulk data chunk")
 
262
        self._recv_trailer()
 
263
        return bulk
 
264
 
 
265
    def _recv_trailer(self):
 
266
        resp = self._recv_tuple()
 
267
        if resp == ('done', ):
 
268
            return
 
269
        else:
 
270
            self._translate_error(resp)
 
271
 
 
272
    def _encode_bulk_data(self, body):
 
273
        """Encode body as a bulk data chunk."""
 
274
        return ''.join(('%d\n' % len(body), body, 'done\n'))
 
275
 
 
276
    def _serialise_offsets(self, offsets):
 
277
        """Serialise a readv offset list."""
 
278
        txt = []
 
279
        for start, length in offsets:
 
280
            txt.append('%d,%d' % (start, length))
 
281
        return '\n'.join(txt)
 
282
 
 
283
    def _send_bulk_data(self, body, a_file=None):
 
284
        """Send chunked body data"""
 
285
        assert isinstance(body, str)
 
286
        bytes = self._encode_bulk_data(body)
 
287
        self._write_and_flush(bytes, a_file)
 
288
 
 
289
    def _write_and_flush(self, bytes, a_file=None):
 
290
        """Write bytes to self._out and flush it."""
 
291
        # XXX: this will be inefficient.  Just ask Robert.
 
292
        if a_file is None:
 
293
            a_file = self._out
 
294
        a_file.write(bytes)
 
295
        a_file.flush()
 
296
        
 
297
 
 
298
class SmartServerRequestProtocolOne(SmartProtocolBase):
 
299
    """Server-side encoding and decoding logic for smart version 1."""
 
300
    
 
301
    def __init__(self, output_stream, backing_transport):
 
302
        self._out_stream = output_stream
 
303
        self._backing_transport = backing_transport
 
304
        self.excess_buffer = ''
 
305
        self.finished_reading = False
 
306
        self.in_buffer = ''
 
307
        self.has_dispatched = False
 
308
        self.request = None
 
309
        self._body_decoder = None
 
310
 
 
311
    def accept_bytes(self, bytes):
 
312
        """Take bytes, and advance the internal state machine appropriately.
 
313
        
 
314
        :param bytes: must be a byte string
 
315
        """
 
316
        assert isinstance(bytes, str)
 
317
        self.in_buffer += bytes
 
318
        if not self.has_dispatched:
 
319
            if '\n' not in self.in_buffer:
 
320
                # no command line yet
 
321
                return
 
322
            self.has_dispatched = True
 
323
            # XXX if in_buffer not \n-terminated this will do the wrong
 
324
            # thing.
 
325
            try:
 
326
                first_line, self.in_buffer = self.in_buffer.split('\n', 1)
 
327
                first_line += '\n'
 
328
                req_args = _decode_tuple(first_line)
 
329
                self.request = SmartServerRequestHandler(
 
330
                    self._backing_transport)
 
331
                self.request.dispatch_command(req_args[0], req_args[1:])
 
332
                if self.request.finished_reading:
 
333
                    # trivial request
 
334
                    self.excess_buffer = self.in_buffer
 
335
                    self.in_buffer = ''
 
336
                    self._send_response(self.request.response.args,
 
337
                        self.request.response.body)
 
338
                self.sync_with_request(self.request)
 
339
            except KeyboardInterrupt:
 
340
                raise
 
341
            except Exception, exception:
 
342
                # everything else: pass to client, flush, and quit
 
343
                self._send_response(('error', str(exception)))
 
344
                return None
 
345
 
 
346
        if self.has_dispatched:
 
347
            if self.finished_reading:
 
348
                # nothing to do.XXX: this routine should be a single state 
 
349
                # machine too.
 
350
                self.excess_buffer += self.in_buffer
 
351
                self.in_buffer = ''
 
352
                return
 
353
            if self._body_decoder is None:
 
354
                self._body_decoder = LengthPrefixedBodyDecoder()
 
355
            self._body_decoder.accept_bytes(self.in_buffer)
 
356
            self.in_buffer = self._body_decoder.unused_data
 
357
            body_data = self._body_decoder.read_pending_data()
 
358
            self.request.accept_body(body_data)
 
359
            if self._body_decoder.finished_reading:
 
360
                self.request.end_of_body()
 
361
                assert self.request.finished_reading, \
 
362
                    "no more body, request not finished"
 
363
            self.sync_with_request(self.request)
 
364
            if self.request.response is not None:
 
365
                self._send_response(self.request.response.args,
 
366
                    self.request.response.body)
 
367
                self.excess_buffer = self.in_buffer
 
368
                self.in_buffer = ''
 
369
            else:
 
370
                assert not self.request.finished_reading, \
 
371
                    "no response and we have finished reading."
 
372
 
 
373
    def _send_response(self, args, body=None):
 
374
        """Send a smart server response down the output stream."""
 
375
        self._out_stream.write(_encode_tuple(args))
 
376
        if body is None:
 
377
            self._out_stream.flush()
 
378
        else:
 
379
            self._send_bulk_data(body, self._out_stream)
 
380
            #self._out_stream.write('BLARGH')
 
381
 
 
382
    def sync_with_request(self, request):
 
383
        self.finished_reading = request.finished_reading
 
384
        
 
385
 
 
386
class LengthPrefixedBodyDecoder(object):
 
387
    """Decodes the length-prefixed bulk data."""
 
388
    
 
389
    def __init__(self):
 
390
        self.bytes_left = None
 
391
        self.finished_reading = False
 
392
        self.unused_data = ''
 
393
        self.state_accept = self._state_accept_expecting_length
 
394
        self.state_read = self._state_read_no_data
 
395
        self._in_buffer = ''
 
396
        self._trailer_buffer = ''
 
397
    
 
398
    def accept_bytes(self, bytes):
 
399
        """Decode as much of bytes as possible.
 
400
 
 
401
        If 'bytes' contains too much data it will be appended to
 
402
        self.unused_data.
 
403
 
 
404
        finished_reading will be set when no more data is required.  Further
 
405
        data will be appended to self.unused_data.
 
406
        """
 
407
        # accept_bytes is allowed to change the state
 
408
        current_state = self.state_accept
 
409
        self.state_accept(bytes)
 
410
        while current_state != self.state_accept:
 
411
            current_state = self.state_accept
 
412
            self.state_accept('')
 
413
 
 
414
    def next_read_size(self):
 
415
        if self.bytes_left is not None:
 
416
            # Ideally we want to read all the remainder of the body and the
 
417
            # trailer in one go.
 
418
            return self.bytes_left + 5
 
419
        elif self.state_accept == self._state_accept_reading_trailer:
 
420
            # Just the trailer left
 
421
            return 5 - len(self._trailer_buffer)
 
422
        elif self.state_accept == self._state_accept_expecting_length:
 
423
            # There's still at least 6 bytes left ('\n' to end the length, plus
 
424
            # 'done\n').
 
425
            return 6
 
426
        else:
 
427
            # Reading excess data.  Either way, 1 byte at a time is fine.
 
428
            return 1
 
429
        
 
430
    def read_pending_data(self):
 
431
        """Return any pending data that has been decoded."""
 
432
        return self.state_read()
 
433
 
 
434
    def _state_accept_expecting_length(self, bytes):
 
435
        self._in_buffer += bytes
 
436
        pos = self._in_buffer.find('\n')
 
437
        if pos == -1:
 
438
            return
 
439
        self.bytes_left = int(self._in_buffer[:pos])
 
440
        self._in_buffer = self._in_buffer[pos+1:]
 
441
        self.bytes_left -= len(self._in_buffer)
 
442
        self.state_accept = self._state_accept_reading_body
 
443
        self.state_read = self._state_read_in_buffer
 
444
 
 
445
    def _state_accept_reading_body(self, bytes):
 
446
        self._in_buffer += bytes
 
447
        self.bytes_left -= len(bytes)
 
448
        if self.bytes_left <= 0:
 
449
            # Finished with body
 
450
            if self.bytes_left != 0:
 
451
                self._trailer_buffer = self._in_buffer[self.bytes_left:]
 
452
                self._in_buffer = self._in_buffer[:self.bytes_left]
 
453
            self.bytes_left = None
 
454
            self.state_accept = self._state_accept_reading_trailer
 
455
        
 
456
    def _state_accept_reading_trailer(self, bytes):
 
457
        self._trailer_buffer += bytes
 
458
        # TODO: what if the trailer does not match "done\n"?  Should this raise
 
459
        # a ProtocolViolation exception?
 
460
        if self._trailer_buffer.startswith('done\n'):
 
461
            self.unused_data = self._trailer_buffer[len('done\n'):]
 
462
            self.state_accept = self._state_accept_reading_unused
 
463
            self.finished_reading = True
 
464
    
 
465
    def _state_accept_reading_unused(self, bytes):
 
466
        self.unused_data += bytes
 
467
 
 
468
    def _state_read_no_data(self):
 
469
        return ''
 
470
 
 
471
    def _state_read_in_buffer(self):
 
472
        result = self._in_buffer
 
473
        self._in_buffer = ''
 
474
        return result
 
475
 
 
476
 
 
477
class SmartServerStreamMedium(SmartProtocolBase):
 
478
    """Handles smart commands coming over a stream.
 
479
 
 
480
    The stream may be a pipe connected to sshd, or a tcp socket, or an
 
481
    in-process fifo for testing.
 
482
 
 
483
    One instance is created for each connected client; it can serve multiple
 
484
    requests in the lifetime of the connection.
 
485
 
 
486
    The server passes requests through to an underlying backing transport, 
 
487
    which will typically be a LocalTransport looking at the server's filesystem.
 
488
    """
 
489
 
 
490
    def __init__(self, in_file, out_file, backing_transport):
 
491
        """Construct new server.
 
492
 
 
493
        :param in_file: Python file from which requests can be read.
 
494
        :param out_file: Python file to write responses.
 
495
        :param backing_transport: Transport for the directory served.
 
496
        """
 
497
        self._in = in_file
 
498
        self._out = out_file
 
499
        self.backing_transport = backing_transport
 
500
 
 
501
    def _recv_tuple(self):
 
502
        """Read a request from the client and return as a tuple.
 
503
        
 
504
        Returns None at end of file (if the client closed the connection.)
 
505
        """
 
506
        # ** Deserialise and read bytes
 
507
        return _recv_tuple(self._in)
 
508
 
 
509
    def _send_tuple(self, args):
 
510
        """Send response header"""
 
511
        # ** serialise and write bytes
 
512
        return self._write_and_flush(_encode_tuple(args))
 
513
 
 
514
    def _send_error_and_disconnect(self, exception):
 
515
        # ** serialise and write bytes
 
516
        self._send_tuple(('error', str(exception)))
 
517
        ## self._out.close()
 
518
        ## self._in.close()
 
519
 
 
520
    def serve(self):
 
521
        """Serve requests until the client disconnects."""
 
522
        # Keep a reference to stderr because the sys module's globals get set to
 
523
        # None during interpreter shutdown.
 
524
        from sys import stderr
 
525
        try:
 
526
            while True:
 
527
                protocol = SmartServerRequestProtocolOne(self._out,
 
528
                                                         self.backing_transport)
 
529
                if self._serve_one_request(protocol) == False:
 
530
                    break
 
531
        except Exception, e:
 
532
            stderr.write("%s terminating on exception %s\n" % (self, e))
 
533
            raise
 
534
 
 
535
 
 
536
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
 
537
 
 
538
    def __init__(self, in_socket, out_file, backing_transport):
 
539
        """Constructor.
 
540
 
 
541
        :param in_socket: the socket the server will read from.  It will be put
 
542
            into blocking mode.
 
543
        """
 
544
        in_socket.setblocking(True)
 
545
        SmartServerStreamMedium.__init__(
 
546
            self, in_socket, out_file, backing_transport)
 
547
        self.push_back = ''
 
548
        
 
549
    def _serve_one_request(self, protocol):
 
550
        """Read one request from input, process, send back a response.
 
551
        
 
552
        :param protocol: a SmartServerRequestProtocol.
 
553
        :return: False if the server should terminate, otherwise None.
 
554
        """
 
555
        while not protocol.finished_reading:
 
556
            if self.push_back:
 
557
                protocol.accept_bytes(self.push_back)
 
558
                self.push_back = ''
 
559
            else:
 
560
                bytes = self._in.recv(4096)
 
561
                if bytes == '':
 
562
                    return False
 
563
                protocol.accept_bytes(bytes)
 
564
 
 
565
        self.push_back = protocol.excess_buffer
 
566
    
 
567
 
 
568
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
 
569
 
 
570
    def __init__(self, in_file, out_file, backing_transport):
 
571
        """Construct new server.
 
572
 
 
573
        :param in_file: Python file from which requests can be read.
 
574
        :param out_file: Python file to write responses.
 
575
        :param backing_transport: Transport for the directory served.
 
576
        """
 
577
        SmartServerStreamMedium.__init__(self, in_file, out_file, backing_transport)
 
578
        self._in = in_file
 
579
        self._out = out_file
 
580
 
 
581
    def _serve_one_request(self, protocol):
 
582
        """Read one request from input, process, send back a response.
 
583
        
 
584
        :param protocol: a SmartServerRequestProtocol.
 
585
        :return: False if the server should terminate, otherwise None.
 
586
        """
 
587
        # ** deserialise, read bytes, serialise and write bytes
 
588
        req_line = self._in.readline()
 
589
        # this should just test "req_line == ''", surely?  -- Andrew Bennetts
 
590
        if req_line in ('', None):
 
591
            # client closed connection
 
592
            return False  # shutdown server
 
593
        try:
 
594
            protocol.accept_bytes(req_line)
 
595
            if not protocol.finished_reading:
 
596
                # this boils down to readline which wont block on open sockets
 
597
                # without data. We should really though read as much as is
 
598
                # available and then hand to that accept_bytes without this
 
599
                # silly double-decode.
 
600
                bulk = self._recv_bulk()
 
601
                bulk_bytes = ''.join(('%d\n' % len(bulk), bulk, 'done\n'))
 
602
                protocol.accept_bytes(bulk_bytes)
 
603
                # might be nice to do protocol.end_of_bytes()
 
604
                # because self._recv_bulk reads all the bytes, this must finish
 
605
                # after one delivery of data rather than looping.
 
606
                assert protocol.finished_reading, 'was not finished reading'
 
607
        except KeyboardInterrupt:
 
608
            raise
 
609
        except Exception, e:
 
610
            # everything else: pass to client, flush, and quit
 
611
            self._send_error_and_disconnect(e)
 
612
            return False
 
613
 
 
614
 
 
615
class SmartServerResponse(object):
 
616
    """Response generated by SmartServerRequestHandler."""
 
617
 
 
618
    def __init__(self, args, body=None):
 
619
        self.args = args
 
620
        self.body = body
 
621
 
 
622
# XXX: TODO: Create a SmartServerRequestHandler which will take the responsibility
 
623
# for delivering the data for a request. This could be done with as the
 
624
# StreamServer, though that would create conflation between request and response
 
625
# which may be undesirable.
 
626
 
 
627
 
 
628
class SmartServerRequestHandler(object):
 
629
    """Protocol logic for smart server.
 
630
    
 
631
    This doesn't handle serialization at all, it just processes requests and
 
632
    creates responses.
 
633
    """
 
634
 
 
635
    # IMPORTANT FOR IMPLEMENTORS: It is important that SmartServerRequestHandler
 
636
    # not contain encoding or decoding logic to allow the wire protocol to vary
 
637
    # from the object protocol: we will want to tweak the wire protocol separate
 
638
    # from the object model, and ideally we will be able to do that without
 
639
    # having a SmartServerRequestHandler subclass for each wire protocol, rather
 
640
    # just a Protocol subclass.
 
641
 
 
642
    # TODO: Better way of representing the body for commands that take it,
 
643
    # and allow it to be streamed into the server.
 
644
    
 
645
    def __init__(self, backing_transport):
 
646
        self._backing_transport = backing_transport
 
647
        self._converted_command = False
 
648
        self.finished_reading = False
 
649
        self._body_bytes = ''
 
650
        self.response = None
 
651
 
 
652
    def accept_body(self, bytes):
 
653
        """Accept body data.
 
654
 
 
655
        This should be overriden for each command that desired body data to
 
656
        handle the right format of that data. I.e. plain bytes, a bundle etc.
 
657
 
 
658
        The deserialisation into that format should be done in the Protocol
 
659
        object. Set self.desired_body_format to the format your method will
 
660
        handle.
 
661
        """
 
662
        # default fallback is to accumulate bytes.
 
663
        self._body_bytes += bytes
 
664
        
 
665
    def _end_of_body_handler(self):
 
666
        """An unimplemented end of body handler."""
 
667
        raise NotImplementedError(self._end_of_body_handler)
 
668
        
 
669
    def do_hello(self):
 
670
        """Answer a version request with my version."""
 
671
        return SmartServerResponse(('ok', '1'))
 
672
 
 
673
    def do_has(self, relpath):
 
674
        r = self._backing_transport.has(relpath) and 'yes' or 'no'
 
675
        return SmartServerResponse((r,))
 
676
 
 
677
    def do_get(self, relpath):
 
678
        backing_bytes = self._backing_transport.get_bytes(relpath)
 
679
        return SmartServerResponse(('ok',), backing_bytes)
 
680
 
 
681
    def _deserialise_optional_mode(self, mode):
 
682
        # XXX: FIXME this should be on the protocol object.
 
683
        if mode == '':
 
684
            return None
 
685
        else:
 
686
            return int(mode)
 
687
 
 
688
    def do_append(self, relpath, mode):
 
689
        self._converted_command = True
 
690
        self._relpath = relpath
 
691
        self._mode = self._deserialise_optional_mode(mode)
 
692
        self._end_of_body_handler = self._handle_do_append_end
 
693
    
 
694
    def _handle_do_append_end(self):
 
695
        old_length = self._backing_transport.append_bytes(
 
696
            self._relpath, self._body_bytes, self._mode)
 
697
        self.response = SmartServerResponse(('appended', '%d' % old_length))
 
698
 
 
699
    def do_delete(self, relpath):
 
700
        self._backing_transport.delete(relpath)
 
701
 
 
702
    def do_iter_files_recursive(self, abspath):
 
703
        # XXX: the path handling needs some thought.
 
704
        #relpath = self._backing_transport.relpath(abspath)
 
705
        transport = self._backing_transport.clone(abspath)
 
706
        filenames = transport.iter_files_recursive()
 
707
        return SmartServerResponse(('names',) + tuple(filenames))
 
708
 
 
709
    def do_list_dir(self, relpath):
 
710
        filenames = self._backing_transport.list_dir(relpath)
 
711
        return SmartServerResponse(('names',) + tuple(filenames))
 
712
 
 
713
    def do_mkdir(self, relpath, mode):
 
714
        self._backing_transport.mkdir(relpath,
 
715
                                      self._deserialise_optional_mode(mode))
 
716
 
 
717
    def do_move(self, rel_from, rel_to):
 
718
        self._backing_transport.move(rel_from, rel_to)
 
719
 
 
720
    def do_put(self, relpath, mode):
 
721
        self._converted_command = True
 
722
        self._relpath = relpath
 
723
        self._mode = self._deserialise_optional_mode(mode)
 
724
        self._end_of_body_handler = self._handle_do_put
 
725
 
 
726
    def _handle_do_put(self):
 
727
        self._backing_transport.put_bytes(self._relpath,
 
728
                self._body_bytes, self._mode)
 
729
        self.response = SmartServerResponse(('ok',))
 
730
 
 
731
    def _deserialise_offsets(self, text):
 
732
        # XXX: FIXME this should be on the protocol object.
 
733
        offsets = []
 
734
        for line in text.split('\n'):
 
735
            if not line:
 
736
                continue
 
737
            start, length = line.split(',')
 
738
            offsets.append((int(start), int(length)))
 
739
        return offsets
 
740
 
 
741
    def do_put_non_atomic(self, relpath, mode, create_parent, dir_mode):
 
742
        self._converted_command = True
 
743
        self._end_of_body_handler = self._handle_put_non_atomic
 
744
        self._relpath = relpath
 
745
        self._dir_mode = self._deserialise_optional_mode(dir_mode)
 
746
        self._mode = self._deserialise_optional_mode(mode)
 
747
        # a boolean would be nicer XXX
 
748
        self._create_parent = (create_parent == 'T')
 
749
 
 
750
    def _handle_put_non_atomic(self):
 
751
        self._backing_transport.put_bytes_non_atomic(self._relpath,
 
752
                self._body_bytes,
 
753
                mode=self._mode,
 
754
                create_parent_dir=self._create_parent,
 
755
                dir_mode=self._dir_mode)
 
756
        self.response = SmartServerResponse(('ok',))
 
757
 
 
758
    def do_readv(self, relpath):
 
759
        self._converted_command = True
 
760
        self._end_of_body_handler = self._handle_readv_offsets
 
761
        self._relpath = relpath
 
762
 
 
763
    def end_of_body(self):
 
764
        """No more body data will be received."""
 
765
        self._run_handler_code(self._end_of_body_handler, (), {})
 
766
        # cannot read after this.
 
767
        self.finished_reading = True
 
768
 
 
769
    def _handle_readv_offsets(self):
 
770
        """accept offsets for a readv request."""
 
771
        offsets = self._deserialise_offsets(self._body_bytes)
 
772
        backing_bytes = ''.join(bytes for offset, bytes in
 
773
            self._backing_transport.readv(self._relpath, offsets))
 
774
        self.response = SmartServerResponse(('readv',), backing_bytes)
 
775
        
 
776
    def do_rename(self, rel_from, rel_to):
 
777
        self._backing_transport.rename(rel_from, rel_to)
 
778
 
 
779
    def do_rmdir(self, relpath):
 
780
        self._backing_transport.rmdir(relpath)
 
781
 
 
782
    def do_stat(self, relpath):
 
783
        stat = self._backing_transport.stat(relpath)
 
784
        return SmartServerResponse(('stat', str(stat.st_size), oct(stat.st_mode)))
 
785
        
 
786
    def do_get_bundle(self, path, revision_id):
 
787
        # open transport relative to our base
 
788
        t = self._backing_transport.clone(path)
 
789
        control, extra_path = bzrdir.BzrDir.open_containing_from_transport(t)
 
790
        repo = control.open_repository()
 
791
        tmpf = tempfile.TemporaryFile()
 
792
        base_revision = revision.NULL_REVISION
 
793
        write_bundle(repo, revision_id, base_revision, tmpf)
 
794
        tmpf.seek(0)
 
795
        return SmartServerResponse((), tmpf.read())
 
796
 
 
797
    def dispatch_command(self, cmd, args):
 
798
        """Deprecated compatibility method.""" # XXX XXX
 
799
        func = getattr(self, 'do_' + cmd, None)
 
800
        if func is None:
 
801
            raise errors.SmartProtocolError("bad request %r" % (cmd,))
 
802
        self._run_handler_code(func, args, {})
 
803
 
 
804
    def _run_handler_code(self, callable, args, kwargs):
 
805
        """Run some handler specific code 'callable'.
 
806
 
 
807
        If a result is returned, it is considered to be the commands response,
 
808
        and finished_reading is set true, and its assigned to self.response.
 
809
 
 
810
        Any exceptions caught are translated and a response object created
 
811
        from them.
 
812
        """
 
813
        result = self._call_converting_errors(callable, args, kwargs)
 
814
        if result is not None:
 
815
            self.response = result
 
816
            self.finished_reading = True
 
817
        # handle unconverted commands
 
818
        if not self._converted_command:
 
819
            self.finished_reading = True
 
820
            if result is None:
 
821
                self.response = SmartServerResponse(('ok',))
 
822
 
 
823
    def _call_converting_errors(self, callable, args, kwargs):
 
824
        """Call callable converting errors to Response objects."""
 
825
        try:
 
826
            return callable(*args, **kwargs)
 
827
        except errors.NoSuchFile, e:
 
828
            return SmartServerResponse(('NoSuchFile', e.path))
 
829
        except errors.FileExists, e:
 
830
            return SmartServerResponse(('FileExists', e.path))
 
831
        except errors.DirectoryNotEmpty, e:
 
832
            return SmartServerResponse(('DirectoryNotEmpty', e.path))
 
833
        except errors.ShortReadvError, e:
 
834
            return SmartServerResponse(('ShortReadvError',
 
835
                e.path, str(e.offset), str(e.length), str(e.actual)))
 
836
        except UnicodeError, e:
 
837
            # If it is a DecodeError, than most likely we are starting
 
838
            # with a plain string
 
839
            str_or_unicode = e.object
 
840
            if isinstance(str_or_unicode, unicode):
 
841
                val = u'u:' + str_or_unicode
 
842
            else:
 
843
                val = u's:' + str_or_unicode.encode('base64')
 
844
            # This handles UnicodeEncodeError or UnicodeDecodeError
 
845
            return SmartServerResponse((e.__class__.__name__,
 
846
                    e.encoding, val, str(e.start), str(e.end), e.reason))
 
847
        except errors.TransportNotPossible, e:
 
848
            if e.msg == "readonly transport":
 
849
                return SmartServerResponse(('ReadOnlyError', ))
 
850
            else:
 
851
                raise
 
852
 
 
853
 
 
854
class SmartTCPServer(object):
 
855
    """Listens on a TCP socket and accepts connections from smart clients"""
 
856
 
 
857
    def __init__(self, backing_transport, host='127.0.0.1', port=0):
 
858
        """Construct a new server.
 
859
 
 
860
        To actually start it running, call either start_background_thread or
 
861
        serve.
 
862
 
 
863
        :param host: Name of the interface to listen on.
 
864
        :param port: TCP port to listen on, or 0 to allocate a transient port.
 
865
        """
 
866
        self._server_socket = socket.socket()
 
867
        self._server_socket.bind((host, port))
 
868
        self.port = self._server_socket.getsockname()[1]
 
869
        self._server_socket.listen(1)
 
870
        self._server_socket.settimeout(1)
 
871
        self.backing_transport = backing_transport
 
872
 
 
873
    def serve(self):
 
874
        # let connections timeout so that we get a chance to terminate
 
875
        # Keep a reference to the exceptions we want to catch because the socket
 
876
        # module's globals get set to None during interpreter shutdown.
 
877
        from socket import timeout as socket_timeout
 
878
        from socket import error as socket_error
 
879
        self._should_terminate = False
 
880
        while not self._should_terminate:
 
881
            try:
 
882
                self.accept_and_serve()
 
883
            except socket_timeout:
 
884
                # just check if we're asked to stop
 
885
                pass
 
886
            except socket_error, e:
 
887
                trace.warning("client disconnected: %s", e)
 
888
                pass
 
889
 
 
890
    def get_url(self):
 
891
        """Return the url of the server"""
 
892
        return "bzr://%s:%d/" % self._server_socket.getsockname()
 
893
 
 
894
    def accept_and_serve(self):
 
895
        conn, client_addr = self._server_socket.accept()
 
896
        # For WIN32, where the timeout value from the listening socket
 
897
        # propogates to the newly accepted socket.
 
898
        conn.setblocking(True)
 
899
        conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
900
        from_client = conn.makefile('r')
 
901
        to_client = conn.makefile('w')
 
902
        handler = SmartServerSocketStreamMedium(conn, to_client,
 
903
                self.backing_transport)
 
904
        connection_thread = threading.Thread(None, handler.serve, name='smart-server-child')
 
905
        connection_thread.setDaemon(True)
 
906
        connection_thread.start()
 
907
 
 
908
    def start_background_thread(self):
 
909
        self._server_thread = threading.Thread(None,
 
910
                self.serve,
 
911
                name='server-' + self.get_url())
 
912
        self._server_thread.setDaemon(True)
 
913
        self._server_thread.start()
 
914
 
 
915
    def stop_background_thread(self):
 
916
        self._should_terminate = True
 
917
        # self._server_socket.close()
 
918
        # we used to join the thread, but it's not really necessary; it will
 
919
        # terminate in time
 
920
        ## self._server_thread.join()
 
921
 
 
922
 
 
923
class SmartTCPServer_for_testing(SmartTCPServer):
 
924
    """Server suitable for use by transport tests.
 
925
    
 
926
    This server is backed by the process's cwd.
 
927
    """
 
928
 
 
929
    def __init__(self):
 
930
        self._homedir = os.getcwd()
 
931
        # The server is set up by default like for ssh access: the client
 
932
        # passes filesystem-absolute paths; therefore the server must look
 
933
        # them up relative to the root directory.  it might be better to act
 
934
        # a public server and have the server rewrite paths into the test
 
935
        # directory.
 
936
        SmartTCPServer.__init__(self, transport.get_transport("file:///"))
 
937
        
 
938
    def setUp(self):
 
939
        """Set up server for testing"""
 
940
        self.start_background_thread()
 
941
 
 
942
    def tearDown(self):
 
943
        self.stop_background_thread()
 
944
 
 
945
    def get_url(self):
 
946
        """Return the url of the server"""
 
947
        host, port = self._server_socket.getsockname()
 
948
        # XXX: I think this is likely to break on windows -- self._homedir will
 
949
        # have backslashes (and maybe a drive letter?).
 
950
        #  -- Andrew Bennetts, 2006-08-29
 
951
        return "bzr://%s:%d%s" % (host, port, urlutils.escape(self._homedir))
 
952
 
 
953
    def get_bogus_url(self):
 
954
        """Return a URL which will fail to connect"""
 
955
        return 'bzr://127.0.0.1:1/'
 
956
 
 
957
 
 
958
class SmartStat(object):
 
959
 
 
960
    def __init__(self, size, mode):
 
961
        self.st_size = size
 
962
        self.st_mode = mode
 
963
 
 
964
 
 
965
class SmartTransport(transport.Transport):
 
966
    """Connection to a smart server.
 
967
 
 
968
    The connection holds references to pipes that can be used to send requests
 
969
    to the server.
 
970
 
 
971
    The connection has a notion of the current directory to which it's
 
972
    connected; this is incorporated in filenames passed to the server.
 
973
    
 
974
    This supports some higher-level RPC operations and can also be treated 
 
975
    like a Transport to do file-like operations.
 
976
 
 
977
    The connection can be made over a tcp socket, or (in future) an ssh pipe
 
978
    or a series of http requests.  There are concrete subclasses for each
 
979
    type: SmartTCPTransport, etc.
 
980
    """
 
981
 
 
982
    # IMPORTANT FOR IMPLEMENTORS: SmartTransport MUST NOT be given encoding
 
983
    # responsibilities: Put those on SmartClient or similar. This is vital for
 
984
    # the ability to support multiple versions of the smart protocol over time:
 
985
    # SmartTransport is an adapter from the Transport object model to the 
 
986
    # SmartClient model, not an encoder.
 
987
 
 
988
    def __init__(self, url, clone_from=None, medium=None):
 
989
        """Constructor.
 
990
 
 
991
        :param medium: The medium to use for this RemoteTransport. This must be
 
992
            supplied if clone_from is None.
 
993
        """
 
994
        ### Technically super() here is faulty because Transport's __init__
 
995
        ### fails to take 2 parameters, and if super were to choose a silly
 
996
        ### initialisation order things would blow up. 
 
997
        if not url.endswith('/'):
 
998
            url += '/'
 
999
        super(SmartTransport, self).__init__(url)
 
1000
        self._scheme, self._username, self._password, self._host, self._port, self._path = \
 
1001
                transport.split_url(url)
 
1002
        if clone_from is None:
 
1003
            self._medium = medium
 
1004
        else:
 
1005
            # credentials may be stripped from the base in some circumstances
 
1006
            # as yet to be clearly defined or documented, so copy them.
 
1007
            self._username = clone_from._username
 
1008
            # reuse same connection
 
1009
            self._medium = clone_from._medium
 
1010
        assert self._medium is not None
 
1011
 
 
1012
    def abspath(self, relpath):
 
1013
        """Return the full url to the given relative path.
 
1014
        
 
1015
        @param relpath: the relative path or path components
 
1016
        @type relpath: str or list
 
1017
        """
 
1018
        return self._unparse_url(self._remote_path(relpath))
 
1019
    
 
1020
    def clone(self, relative_url):
 
1021
        """Make a new SmartTransport related to me, sharing the same connection.
 
1022
 
 
1023
        This essentially opens a handle on a different remote directory.
 
1024
        """
 
1025
        if relative_url is None:
 
1026
            return SmartTransport(self.base, self)
 
1027
        else:
 
1028
            return SmartTransport(self.abspath(relative_url), self)
 
1029
 
 
1030
    def is_readonly(self):
 
1031
        """Smart server transport can do read/write file operations."""
 
1032
        return False
 
1033
                                                   
 
1034
    def get_smart_client(self):
 
1035
        return self._medium
 
1036
 
 
1037
    def get_smart_medium(self):
 
1038
        return self._medium
 
1039
                                                   
 
1040
    def _unparse_url(self, path):
 
1041
        """Return URL for a path.
 
1042
 
 
1043
        :see: SFTPUrlHandling._unparse_url
 
1044
        """
 
1045
        # TODO: Eventually it should be possible to unify this with
 
1046
        # SFTPUrlHandling._unparse_url?
 
1047
        if path == '':
 
1048
            path = '/'
 
1049
        path = urllib.quote(path)
 
1050
        netloc = urllib.quote(self._host)
 
1051
        if self._username is not None:
 
1052
            netloc = '%s@%s' % (urllib.quote(self._username), netloc)
 
1053
        if self._port is not None:
 
1054
            netloc = '%s:%d' % (netloc, self._port)
 
1055
        return urlparse.urlunparse((self._scheme, netloc, path, '', '', ''))
 
1056
 
 
1057
    def _remote_path(self, relpath):
 
1058
        """Returns the Unicode version of the absolute path for relpath."""
 
1059
        return self._combine_paths(self._path, relpath)
 
1060
 
 
1061
    def _call(self, method, *args):
 
1062
        resp = self._call2(method, *args)
 
1063
        self._translate_error(resp)
 
1064
 
 
1065
    def _call2(self, method, *args):
 
1066
        """Call a method on the remote server."""
 
1067
        protocol = SmartClientRequestProtocolOne(self._medium.get_request())
 
1068
        protocol.call(method, *args)
 
1069
        return protocol.read_response_tuple()
 
1070
 
 
1071
    def _call_with_body_bytes(self, method, args, body):
 
1072
        """Call a method on the remote server with body bytes."""
 
1073
        protocol = SmartClientRequestProtocolOne(self._medium.get_request())
 
1074
        protocol.call_with_body_bytes((method, ) + args, body)
 
1075
        return protocol.read_response_tuple()
 
1076
 
 
1077
    def has(self, relpath):
 
1078
        """Indicate whether a remote file of the given name exists or not.
 
1079
 
 
1080
        :see: Transport.has()
 
1081
        """
 
1082
        resp = self._call2('has', self._remote_path(relpath))
 
1083
        if resp == ('yes', ):
 
1084
            return True
 
1085
        elif resp == ('no', ):
 
1086
            return False
 
1087
        else:
 
1088
            self._translate_error(resp)
 
1089
 
 
1090
    def get(self, relpath):
 
1091
        """Return file-like object reading the contents of a remote file.
 
1092
        
 
1093
        :see: Transport.get_bytes()/get_file()
 
1094
        """
 
1095
        return StringIO(self.get_bytes(relpath))
 
1096
 
 
1097
    def get_bytes(self, relpath):
 
1098
        remote = self._remote_path(relpath)
 
1099
        protocol = SmartClientRequestProtocolOne(self._medium.get_request())
 
1100
        protocol.call('get', remote)
 
1101
        resp = protocol.read_response_tuple(True)
 
1102
        if resp != ('ok', ):
 
1103
            protocol.cancel_read_body()
 
1104
            self._translate_error(resp, relpath)
 
1105
        return protocol.read_body_bytes()
 
1106
 
 
1107
    def _serialise_optional_mode(self, mode):
 
1108
        if mode is None:
 
1109
            return ''
 
1110
        else:
 
1111
            return '%d' % mode
 
1112
 
 
1113
    def mkdir(self, relpath, mode=None):
 
1114
        resp = self._call2('mkdir', self._remote_path(relpath),
 
1115
            self._serialise_optional_mode(mode))
 
1116
        self._translate_error(resp)
 
1117
 
 
1118
    def put_bytes(self, relpath, upload_contents, mode=None):
 
1119
        # FIXME: upload_file is probably not safe for non-ascii characters -
 
1120
        # should probably just pass all parameters as length-delimited
 
1121
        # strings?
 
1122
        resp = self._call_with_body_bytes('put',
 
1123
            (self._remote_path(relpath), self._serialise_optional_mode(mode)),
 
1124
            upload_contents)
 
1125
        self._translate_error(resp)
 
1126
 
 
1127
    def put_bytes_non_atomic(self, relpath, bytes, mode=None,
 
1128
                             create_parent_dir=False,
 
1129
                             dir_mode=None):
 
1130
        """See Transport.put_bytes_non_atomic."""
 
1131
        # FIXME: no encoding in the transport!
 
1132
        create_parent_str = 'F'
 
1133
        if create_parent_dir:
 
1134
            create_parent_str = 'T'
 
1135
 
 
1136
        resp = self._call_with_body_bytes(
 
1137
            'put_non_atomic',
 
1138
            (self._remote_path(relpath), self._serialise_optional_mode(mode),
 
1139
             create_parent_str, self._serialise_optional_mode(dir_mode)),
 
1140
            bytes)
 
1141
        self._translate_error(resp)
 
1142
 
 
1143
    def put_file(self, relpath, upload_file, mode=None):
 
1144
        # its not ideal to seek back, but currently put_non_atomic_file depends
 
1145
        # on transports not reading before failing - which is a faulty
 
1146
        # assumption I think - RBC 20060915
 
1147
        pos = upload_file.tell()
 
1148
        try:
 
1149
            return self.put_bytes(relpath, upload_file.read(), mode)
 
1150
        except:
 
1151
            upload_file.seek(pos)
 
1152
            raise
 
1153
 
 
1154
    def put_file_non_atomic(self, relpath, f, mode=None,
 
1155
                            create_parent_dir=False,
 
1156
                            dir_mode=None):
 
1157
        return self.put_bytes_non_atomic(relpath, f.read(), mode=mode,
 
1158
                                         create_parent_dir=create_parent_dir,
 
1159
                                         dir_mode=dir_mode)
 
1160
 
 
1161
    def append_file(self, relpath, from_file, mode=None):
 
1162
        return self.append_bytes(relpath, from_file.read(), mode)
 
1163
        
 
1164
    def append_bytes(self, relpath, bytes, mode=None):
 
1165
        resp = self._call_with_body_bytes(
 
1166
            'append',
 
1167
            (self._remote_path(relpath), self._serialise_optional_mode(mode)),
 
1168
            bytes)
 
1169
        if resp[0] == 'appended':
 
1170
            return int(resp[1])
 
1171
        self._translate_error(resp)
 
1172
 
 
1173
    def delete(self, relpath):
 
1174
        resp = self._call2('delete', self._remote_path(relpath))
 
1175
        self._translate_error(resp)
 
1176
 
 
1177
    def readv(self, relpath, offsets):
 
1178
        if not offsets:
 
1179
            return
 
1180
 
 
1181
        offsets = list(offsets)
 
1182
 
 
1183
        sorted_offsets = sorted(offsets)
 
1184
        # turn the list of offsets into a stack
 
1185
        offset_stack = iter(offsets)
 
1186
        cur_offset_and_size = offset_stack.next()
 
1187
        coalesced = list(self._coalesce_offsets(sorted_offsets,
 
1188
                               limit=self._max_readv_combine,
 
1189
                               fudge_factor=self._bytes_to_read_before_seek))
 
1190
 
 
1191
        protocol = SmartClientRequestProtocolOne(self._medium.get_request())
 
1192
        protocol.call_with_body_readv_array(
 
1193
            ('readv', self._remote_path(relpath)),
 
1194
            [(c.start, c.length) for c in coalesced])
 
1195
        resp = protocol.read_response_tuple(True)
 
1196
 
 
1197
        if resp[0] != 'readv':
 
1198
            # This should raise an exception
 
1199
            protocol.cancel_read_body()
 
1200
            self._translate_error(resp)
 
1201
            return
 
1202
 
 
1203
        # FIXME: this should know how many bytes are needed, for clarity.
 
1204
        data = protocol.read_body_bytes()
 
1205
        # Cache the results, but only until they have been fulfilled
 
1206
        data_map = {}
 
1207
        for c_offset in coalesced:
 
1208
            if len(data) < c_offset.length:
 
1209
                raise errors.ShortReadvError(relpath, c_offset.start,
 
1210
                            c_offset.length, actual=len(data))
 
1211
            for suboffset, subsize in c_offset.ranges:
 
1212
                key = (c_offset.start+suboffset, subsize)
 
1213
                data_map[key] = data[suboffset:suboffset+subsize]
 
1214
            data = data[c_offset.length:]
 
1215
 
 
1216
            # Now that we've read some data, see if we can yield anything back
 
1217
            while cur_offset_and_size in data_map:
 
1218
                this_data = data_map.pop(cur_offset_and_size)
 
1219
                yield cur_offset_and_size[0], this_data
 
1220
                cur_offset_and_size = offset_stack.next()
 
1221
 
 
1222
    def rename(self, rel_from, rel_to):
 
1223
        self._call('rename',
 
1224
                   self._remote_path(rel_from),
 
1225
                   self._remote_path(rel_to))
 
1226
 
 
1227
    def move(self, rel_from, rel_to):
 
1228
        self._call('move',
 
1229
                   self._remote_path(rel_from),
 
1230
                   self._remote_path(rel_to))
 
1231
 
 
1232
    def rmdir(self, relpath):
 
1233
        resp = self._call('rmdir', self._remote_path(relpath))
 
1234
 
 
1235
    def _translate_error(self, resp, orig_path=None):
 
1236
        """Raise an exception from a response"""
 
1237
        if resp is None:
 
1238
            what = None
 
1239
        else:
 
1240
            what = resp[0]
 
1241
        if what == 'ok':
 
1242
            return
 
1243
        elif what == 'NoSuchFile':
 
1244
            if orig_path is not None:
 
1245
                error_path = orig_path
 
1246
            else:
 
1247
                error_path = resp[1]
 
1248
            raise errors.NoSuchFile(error_path)
 
1249
        elif what == 'error':
 
1250
            raise errors.SmartProtocolError(unicode(resp[1]))
 
1251
        elif what == 'FileExists':
 
1252
            raise errors.FileExists(resp[1])
 
1253
        elif what == 'DirectoryNotEmpty':
 
1254
            raise errors.DirectoryNotEmpty(resp[1])
 
1255
        elif what == 'ShortReadvError':
 
1256
            raise errors.ShortReadvError(resp[1], int(resp[2]),
 
1257
                                         int(resp[3]), int(resp[4]))
 
1258
        elif what in ('UnicodeEncodeError', 'UnicodeDecodeError'):
 
1259
            encoding = str(resp[1]) # encoding must always be a string
 
1260
            val = resp[2]
 
1261
            start = int(resp[3])
 
1262
            end = int(resp[4])
 
1263
            reason = str(resp[5]) # reason must always be a string
 
1264
            if val.startswith('u:'):
 
1265
                val = val[2:]
 
1266
            elif val.startswith('s:'):
 
1267
                val = val[2:].decode('base64')
 
1268
            if what == 'UnicodeDecodeError':
 
1269
                raise UnicodeDecodeError(encoding, val, start, end, reason)
 
1270
            elif what == 'UnicodeEncodeError':
 
1271
                raise UnicodeEncodeError(encoding, val, start, end, reason)
 
1272
        elif what == "ReadOnlyError":
 
1273
            raise errors.TransportNotPossible('readonly transport')
 
1274
        else:
 
1275
            raise errors.SmartProtocolError('unexpected smart server error: %r' % (resp,))
 
1276
 
 
1277
    def disconnect(self):
 
1278
        self._medium.disconnect()
 
1279
 
 
1280
    def delete_tree(self, relpath):
 
1281
        raise errors.TransportNotPossible('readonly transport')
 
1282
 
 
1283
    def stat(self, relpath):
 
1284
        resp = self._call2('stat', self._remote_path(relpath))
 
1285
        if resp[0] == 'stat':
 
1286
            return SmartStat(int(resp[1]), int(resp[2], 8))
 
1287
        else:
 
1288
            self._translate_error(resp)
 
1289
 
 
1290
    ## def lock_read(self, relpath):
 
1291
    ##     """Lock the given file for shared (read) access.
 
1292
    ##     :return: A lock object, which should be passed to Transport.unlock()
 
1293
    ##     """
 
1294
    ##     # The old RemoteBranch ignore lock for reading, so we will
 
1295
    ##     # continue that tradition and return a bogus lock object.
 
1296
    ##     class BogusLock(object):
 
1297
    ##         def __init__(self, path):
 
1298
    ##             self.path = path
 
1299
    ##         def unlock(self):
 
1300
    ##             pass
 
1301
    ##     return BogusLock(relpath)
 
1302
 
 
1303
    def listable(self):
 
1304
        return True
 
1305
 
 
1306
    def list_dir(self, relpath):
 
1307
        resp = self._call2('list_dir', self._remote_path(relpath))
 
1308
        if resp[0] == 'names':
 
1309
            return [name.encode('ascii') for name in resp[1:]]
 
1310
        else:
 
1311
            self._translate_error(resp)
 
1312
 
 
1313
    def iter_files_recursive(self):
 
1314
        resp = self._call2('iter_files_recursive', self._remote_path(''))
 
1315
        if resp[0] == 'names':
 
1316
            return resp[1:]
 
1317
        else:
 
1318
            self._translate_error(resp)
 
1319
 
 
1320
 
 
1321
class SmartClientMediumRequest(object):
 
1322
    """A request on a SmartClientMedium.
 
1323
 
 
1324
    Each request allows bytes to be provided to it via accept_bytes, and then
 
1325
    the response bytes to be read via read_bytes.
 
1326
 
 
1327
    For instance:
 
1328
    request.accept_bytes('123')
 
1329
    request.finished_writing()
 
1330
    result = request.read_bytes(3)
 
1331
    request.finished_reading()
 
1332
 
 
1333
    It is up to the individual SmartClientMedium whether multiple concurrent
 
1334
    requests can exist. See SmartClientMedium.get_request to obtain instances 
 
1335
    of SmartClientMediumRequest, and the concrete Medium you are using for 
 
1336
    details on concurrency and pipelining.
 
1337
    """
 
1338
 
 
1339
    def __init__(self, medium):
 
1340
        """Construct a SmartClientMediumRequest for the medium medium."""
 
1341
        self._medium = medium
 
1342
        # we track state by constants - we may want to use the same
 
1343
        # pattern as BodyReader if it gets more complex.
 
1344
        # valid states are: "writing", "reading", "done"
 
1345
        self._state = "writing"
 
1346
 
 
1347
    def accept_bytes(self, bytes):
 
1348
        """Accept bytes for inclusion in this request.
 
1349
 
 
1350
        This method may not be be called after finished_writing() has been
 
1351
        called.  It depends upon the Medium whether or not the bytes will be
 
1352
        immediately transmitted. Message based Mediums will tend to buffer the
 
1353
        bytes until finished_writing() is called.
 
1354
 
 
1355
        :param bytes: A bytestring.
 
1356
        """
 
1357
        if self._state != "writing":
 
1358
            raise errors.WritingCompleted(self)
 
1359
        self._accept_bytes(bytes)
 
1360
 
 
1361
    def _accept_bytes(self, bytes):
 
1362
        """Helper for accept_bytes.
 
1363
 
 
1364
        Accept_bytes checks the state of the request to determing if bytes
 
1365
        should be accepted. After that it hands off to _accept_bytes to do the
 
1366
        actual acceptance.
 
1367
        """
 
1368
        raise NotImplementedError(self._accept_bytes)
 
1369
 
 
1370
    def finished_reading(self):
 
1371
        """Inform the request that all desired data has been read.
 
1372
 
 
1373
        This will remove the request from the pipeline for its medium (if the
 
1374
        medium supports pipelining) and any further calls to methods on the
 
1375
        request will raise ReadingCompleted.
 
1376
        """
 
1377
        if self._state == "writing":
 
1378
            raise errors.WritingNotComplete(self)
 
1379
        if self._state != "reading":
 
1380
            raise errors.ReadingCompleted(self)
 
1381
        self._state = "done"
 
1382
        self._finished_reading()
 
1383
 
 
1384
    def _finished_reading(self):
 
1385
        """Helper for finished_reading.
 
1386
 
 
1387
        finished_reading checks the state of the request to determine if 
 
1388
        finished_reading is allowed, and if it is hands off to _finished_reading
 
1389
        to perform the action.
 
1390
        """
 
1391
        raise NotImplementedError(self._finished_reading)
 
1392
 
 
1393
    def finished_writing(self):
 
1394
        """Finish the writing phase of this request.
 
1395
 
 
1396
        This will flush all pending data for this request along the medium.
 
1397
        After calling finished_writing, you may not call accept_bytes anymore.
 
1398
        """
 
1399
        if self._state != "writing":
 
1400
            raise errors.WritingCompleted(self)
 
1401
        self._state = "reading"
 
1402
        self._finished_writing()
 
1403
 
 
1404
    def _finished_writing(self):
 
1405
        """Helper for finished_writing.
 
1406
 
 
1407
        finished_writing checks the state of the request to determine if 
 
1408
        finished_writing is allowed, and if it is hands off to _finished_writing
 
1409
        to perform the action.
 
1410
        """
 
1411
        raise NotImplementedError(self._finished_writing)
 
1412
 
 
1413
    def read_bytes(self, count):
 
1414
        """Read bytes from this requests response.
 
1415
 
 
1416
        This method will block and wait for count bytes to be read. It may not
 
1417
        be invoked until finished_writing() has been called - this is to ensure
 
1418
        a message-based approach to requests, for compatability with message
 
1419
        based mediums like HTTP.
 
1420
        """
 
1421
        if self._state == "writing":
 
1422
            raise errors.WritingNotComplete(self)
 
1423
        if self._state != "reading":
 
1424
            raise errors.ReadingCompleted(self)
 
1425
        return self._read_bytes(count)
 
1426
 
 
1427
    def _read_bytes(self, count):
 
1428
        """Helper for read_bytes.
 
1429
 
 
1430
        read_bytes checks the state of the request to determing if bytes
 
1431
        should be read. After that it hands off to _read_bytes to do the
 
1432
        actual read.
 
1433
        """
 
1434
        raise NotImplementedError(self._read_bytes)
 
1435
 
 
1436
 
 
1437
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
 
1438
    """A SmartClientMediumRequest that works with an SmartClientStreamMedium."""
 
1439
 
 
1440
    def __init__(self, medium):
 
1441
        SmartClientMediumRequest.__init__(self, medium)
 
1442
        # check that we are safe concurrency wise. If some streams start
 
1443
        # allowing concurrent requests - i.e. via multiplexing - then this
 
1444
        # assert should be moved to SmartClientStreamMedium.get_request,
 
1445
        # and the setting/unsetting of _current_request likewise moved into
 
1446
        # that class : but its unneeded overhead for now. RBC 20060922
 
1447
        if self._medium._current_request is not None:
 
1448
            raise errors.TooManyConcurrentRequests(self._medium)
 
1449
        self._medium._current_request = self
 
1450
 
 
1451
    def _accept_bytes(self, bytes):
 
1452
        """See SmartClientMediumRequest._accept_bytes.
 
1453
        
 
1454
        This forwards to self._medium._accept_bytes because we are operating
 
1455
        on the mediums stream.
 
1456
        """
 
1457
        self._medium._accept_bytes(bytes)
 
1458
 
 
1459
    def _finished_reading(self):
 
1460
        """See SmartClientMediumRequest._finished_reading.
 
1461
 
 
1462
        This clears the _current_request on self._medium to allow a new 
 
1463
        request to be created.
 
1464
        """
 
1465
        assert self._medium._current_request is self
 
1466
        self._medium._current_request = None
 
1467
        
 
1468
    def _finished_writing(self):
 
1469
        """See SmartClientMediumRequest._finished_writing.
 
1470
 
 
1471
        This invokes self._medium._flush to ensure all bytes are transmitted.
 
1472
        """
 
1473
        self._medium._flush()
 
1474
 
 
1475
    def _read_bytes(self, count):
 
1476
        """See SmartClientMediumRequest._read_bytes.
 
1477
        
 
1478
        This forwards to self._medium._read_bytes because we are operating
 
1479
        on the mediums stream.
 
1480
        """
 
1481
        return self._medium._read_bytes(count)
 
1482
 
 
1483
 
 
1484
class SmartClientRequestProtocolOne(SmartProtocolBase):
 
1485
    """The client-side protocol for smart version 1."""
 
1486
 
 
1487
    def __init__(self, request):
 
1488
        """Construct a SmartClientRequestProtocolOne.
 
1489
 
 
1490
        :param request: A SmartClientMediumRequest to serialise onto and
 
1491
            deserialise from.
 
1492
        """
 
1493
        self._request = request
 
1494
        self._body_buffer = None
 
1495
 
 
1496
    def call(self, *args):
 
1497
        bytes = _encode_tuple(args)
 
1498
        self._request.accept_bytes(bytes)
 
1499
        self._request.finished_writing()
 
1500
 
 
1501
    def call_with_body_bytes(self, args, body):
 
1502
        """Make a remote call of args with body bytes 'body'.
 
1503
 
 
1504
        After calling this, call read_response_tuple to find the result out.
 
1505
        """
 
1506
        bytes = _encode_tuple(args)
 
1507
        self._request.accept_bytes(bytes)
 
1508
        bytes = self._encode_bulk_data(body)
 
1509
        self._request.accept_bytes(bytes)
 
1510
        self._request.finished_writing()
 
1511
 
 
1512
    def call_with_body_readv_array(self, args, body):
 
1513
        """Make a remote call with a readv array.
 
1514
 
 
1515
        The body is encoded with one line per readv offset pair. The numbers in
 
1516
        each pair are separated by a comma, and no trailing \n is emitted.
 
1517
        """
 
1518
        bytes = _encode_tuple(args)
 
1519
        self._request.accept_bytes(bytes)
 
1520
        readv_bytes = self._serialise_offsets(body)
 
1521
        bytes = self._encode_bulk_data(readv_bytes)
 
1522
        self._request.accept_bytes(bytes)
 
1523
        self._request.finished_writing()
 
1524
 
 
1525
    def cancel_read_body(self):
 
1526
        """After expecting a body, a response code may indicate one otherwise.
 
1527
 
 
1528
        This method lets the domain client inform the protocol that no body
 
1529
        will be transmitted. This is a terminal method: after calling it the
 
1530
        protocol is not able to be used further.
 
1531
        """
 
1532
        self._request.finished_reading()
 
1533
 
 
1534
    def read_response_tuple(self, expect_body=False):
 
1535
        """Read a response tuple from the wire.
 
1536
 
 
1537
        This should only be called once.
 
1538
        """
 
1539
        result = self._recv_tuple()
 
1540
        if not expect_body:
 
1541
            self._request.finished_reading()
 
1542
        return result
 
1543
 
 
1544
    def read_body_bytes(self, count=-1):
 
1545
        """Read bytes from the body, decoding into a byte stream.
 
1546
        
 
1547
        We read all bytes at once to ensure we've checked the trailer for 
 
1548
        errors, and then feed the buffer back as read_body_bytes is called.
 
1549
        """
 
1550
        if self._body_buffer is not None:
 
1551
            return self._body_buffer.read(count)
 
1552
        _body_decoder = LengthPrefixedBodyDecoder()
 
1553
 
 
1554
        while not _body_decoder.finished_reading:
 
1555
            bytes_wanted = _body_decoder.next_read_size()
 
1556
            bytes = self._request.read_bytes(bytes_wanted)
 
1557
            _body_decoder.accept_bytes(bytes)
 
1558
        self._request.finished_reading()
 
1559
        self._body_buffer = StringIO(_body_decoder.read_pending_data())
 
1560
        # XXX: TODO check the trailer result.
 
1561
        return self._body_buffer.read(count)
 
1562
 
 
1563
    def _recv_tuple(self):
 
1564
        """Receive a tuple from the medium request."""
 
1565
        line = ''
 
1566
        while not line or line[-1] != '\n':
 
1567
            # TODO: this is inefficient - but tuples are short.
 
1568
            new_char = self._request.read_bytes(1)
 
1569
            line += new_char
 
1570
            assert new_char != '', "end of file reading from server."
 
1571
        return _decode_tuple(line)
 
1572
 
 
1573
    def query_version(self):
 
1574
        """Return protocol version number of the server."""
 
1575
        self.call('hello')
 
1576
        resp = self.read_response_tuple()
 
1577
        if resp == ('ok', '1'):
 
1578
            return 1
 
1579
        else:
 
1580
            raise errors.SmartProtocolError("bad response %r" % (resp,))
 
1581
 
 
1582
 
 
1583
class SmartClientMedium(object):
 
1584
    """Smart client is a medium for sending smart protocol requests over."""
 
1585
 
 
1586
    def disconnect(self):
 
1587
        """If this medium maintains a persistent connection, close it.
 
1588
        
 
1589
        The default implementation does nothing.
 
1590
        """
 
1591
        
 
1592
 
 
1593
class SmartClientStreamMedium(SmartClientMedium):
 
1594
    """Stream based medium common class.
 
1595
 
 
1596
    SmartClientStreamMediums operate on a stream. All subclasses use a common
 
1597
    SmartClientStreamMediumRequest for their requests, and should implement
 
1598
    _accept_bytes and _read_bytes to allow the request objects to send and
 
1599
    receive bytes.
 
1600
    """
 
1601
 
 
1602
    def __init__(self):
 
1603
        self._current_request = None
 
1604
 
 
1605
    def accept_bytes(self, bytes):
 
1606
        self._accept_bytes(bytes)
 
1607
 
 
1608
    def __del__(self):
 
1609
        """The SmartClientStreamMedium knows how to close the stream when it is
 
1610
        finished with it.
 
1611
        """
 
1612
        self.disconnect()
 
1613
 
 
1614
    def _flush(self):
 
1615
        """Flush the output stream.
 
1616
        
 
1617
        This method is used by the SmartClientStreamMediumRequest to ensure that
 
1618
        all data for a request is sent, to avoid long timeouts or deadlocks.
 
1619
        """
 
1620
        raise NotImplementedError(self._flush)
 
1621
 
 
1622
    def get_request(self):
 
1623
        """See SmartClientMedium.get_request().
 
1624
 
 
1625
        SmartClientStreamMedium always returns a SmartClientStreamMediumRequest
 
1626
        for get_request.
 
1627
        """
 
1628
        return SmartClientStreamMediumRequest(self)
 
1629
 
 
1630
    def read_bytes(self, count):
 
1631
        return self._read_bytes(count)
 
1632
 
 
1633
 
 
1634
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
 
1635
    """A client medium using simple pipes.
 
1636
    
 
1637
    This client does not manage the pipes: it assumes they will always be open.
 
1638
    """
 
1639
 
 
1640
    def __init__(self, readable_pipe, writeable_pipe):
 
1641
        SmartClientStreamMedium.__init__(self)
 
1642
        self._readable_pipe = readable_pipe
 
1643
        self._writeable_pipe = writeable_pipe
 
1644
 
 
1645
    def _accept_bytes(self, bytes):
 
1646
        """See SmartClientStreamMedium.accept_bytes."""
 
1647
        self._writeable_pipe.write(bytes)
 
1648
 
 
1649
    def _flush(self):
 
1650
        """See SmartClientStreamMedium._flush()."""
 
1651
        self._writeable_pipe.flush()
 
1652
 
 
1653
    def _read_bytes(self, count):
 
1654
        """See SmartClientStreamMedium._read_bytes."""
 
1655
        return self._readable_pipe.read(count)
 
1656
 
 
1657
 
 
1658
class SmartSSHClientMedium(SmartClientStreamMedium):
 
1659
    """A client medium using SSH."""
 
1660
    
 
1661
    def __init__(self, host, port=None, username=None, password=None,
 
1662
            vendor=None):
 
1663
        """Creates a client that will connect on the first use.
 
1664
        
 
1665
        :param vendor: An optional override for the ssh vendor to use. See
 
1666
            bzrlib.transport.ssh for details on ssh vendors.
 
1667
        """
 
1668
        SmartClientStreamMedium.__init__(self)
 
1669
        self._connected = False
 
1670
        self._host = host
 
1671
        self._password = password
 
1672
        self._port = port
 
1673
        self._username = username
 
1674
        self._read_from = None
 
1675
        self._ssh_connection = None
 
1676
        self._vendor = vendor
 
1677
        self._write_to = None
 
1678
 
 
1679
    def _accept_bytes(self, bytes):
 
1680
        """See SmartClientStreamMedium.accept_bytes."""
 
1681
        self._ensure_connection()
 
1682
        self._write_to.write(bytes)
 
1683
 
 
1684
    def disconnect(self):
 
1685
        """See SmartClientMedium.disconnect()."""
 
1686
        if not self._connected:
 
1687
            return
 
1688
        self._read_from.close()
 
1689
        self._write_to.close()
 
1690
        self._ssh_connection.close()
 
1691
        self._connected = False
 
1692
 
 
1693
    def _ensure_connection(self):
 
1694
        """Connect this medium if not already connected."""
 
1695
        if self._connected:
 
1696
            return
 
1697
        executable = os.environ.get('BZR_REMOTE_PATH', 'bzr')
 
1698
        if self._vendor is None:
 
1699
            vendor = ssh._get_ssh_vendor()
 
1700
        else:
 
1701
            vendor = self._vendor
 
1702
        self._ssh_connection = vendor.connect_ssh(self._username,
 
1703
                self._password, self._host, self._port,
 
1704
                command=[executable, 'serve', '--inet', '--directory=/',
 
1705
                         '--allow-writes'])
 
1706
        self._read_from, self._write_to = \
 
1707
            self._ssh_connection.get_filelike_channels()
 
1708
        self._connected = True
 
1709
 
 
1710
    def _flush(self):
 
1711
        """See SmartClientStreamMedium._flush()."""
 
1712
        self._write_to.flush()
 
1713
 
 
1714
    def _read_bytes(self, count):
 
1715
        """See SmartClientStreamMedium.read_bytes."""
 
1716
        if not self._connected:
 
1717
            raise errors.MediumNotConnected(self)
 
1718
        return self._read_from.read(count)
 
1719
 
 
1720
 
 
1721
class SmartTCPClientMedium(SmartClientStreamMedium):
 
1722
    """A client medium using TCP."""
 
1723
    
 
1724
    def __init__(self, host, port):
 
1725
        """Creates a client that will connect on the first use."""
 
1726
        SmartClientStreamMedium.__init__(self)
 
1727
        self._connected = False
 
1728
        self._host = host
 
1729
        self._port = port
 
1730
        self._socket = None
 
1731
 
 
1732
    def _accept_bytes(self, bytes):
 
1733
        """See SmartClientMedium.accept_bytes."""
 
1734
        self._ensure_connection()
 
1735
        self._socket.sendall(bytes)
 
1736
 
 
1737
    def disconnect(self):
 
1738
        """See SmartClientMedium.disconnect()."""
 
1739
        if not self._connected:
 
1740
            return
 
1741
        self._socket.close()
 
1742
        self._socket = None
 
1743
        self._connected = False
 
1744
 
 
1745
    def _ensure_connection(self):
 
1746
        """Connect this medium if not already connected."""
 
1747
        if self._connected:
 
1748
            return
 
1749
        self._socket = socket.socket()
 
1750
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
1751
        result = self._socket.connect_ex((self._host, int(self._port)))
 
1752
        if result:
 
1753
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
 
1754
                    (self._host, self._port, os.strerror(result)))
 
1755
        self._connected = True
 
1756
 
 
1757
    def _flush(self):
 
1758
        """See SmartClientStreamMedium._flush().
 
1759
        
 
1760
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and 
 
1761
        add a means to do a flush, but that can be done in the future.
 
1762
        """
 
1763
 
 
1764
    def _read_bytes(self, count):
 
1765
        """See SmartClientMedium.read_bytes."""
 
1766
        if not self._connected:
 
1767
            raise errors.MediumNotConnected(self)
 
1768
        return self._socket.recv(count)
 
1769
 
 
1770
 
 
1771
class SmartTCPTransport(SmartTransport):
 
1772
    """Connection to smart server over plain tcp.
 
1773
    
 
1774
    This is essentially just a factory to get 'RemoteTransport(url,
 
1775
        SmartTCPClientMedium).
 
1776
    """
 
1777
 
 
1778
    def __init__(self, url):
 
1779
        _scheme, _username, _password, _host, _port, _path = \
 
1780
            transport.split_url(url)
 
1781
        try:
 
1782
            _port = int(_port)
 
1783
        except (ValueError, TypeError), e:
 
1784
            raise errors.InvalidURL(path=url, extra="invalid port %s" % _port)
 
1785
        medium = SmartTCPClientMedium(_host, _port)
 
1786
        super(SmartTCPTransport, self).__init__(url, medium=medium)
 
1787
 
 
1788
 
 
1789
try:
 
1790
    from bzrlib.transport import ssh
 
1791
except errors.ParamikoNotPresent:
 
1792
    # no paramiko, no SSHTransport.
 
1793
    pass
 
1794
else:
 
1795
    class SmartSSHTransport(SmartTransport):
 
1796
        """Connection to smart server over SSH.
 
1797
 
 
1798
        This is essentially just a factory to get 'RemoteTransport(url,
 
1799
            SmartSSHClientMedium).
 
1800
        """
 
1801
 
 
1802
        def __init__(self, url):
 
1803
            _scheme, _username, _password, _host, _port, _path = \
 
1804
                transport.split_url(url)
 
1805
            try:
 
1806
                if _port is not None:
 
1807
                    _port = int(_port)
 
1808
            except (ValueError, TypeError), e:
 
1809
                raise errors.InvalidURL(path=url, extra="invalid port %s" % 
 
1810
                    _port)
 
1811
            medium = SmartSSHClientMedium(_host, _port, _username, _password)
 
1812
            super(SmartSSHTransport, self).__init__(url, medium=medium)
 
1813
 
 
1814
 
 
1815
def get_test_permutations():
 
1816
    """Return (transport, server) permutations for testing."""
 
1817
    ### We may need a little more test framework support to construct an
 
1818
    ### appropriate RemoteTransport in the future.
 
1819
    return [(SmartTCPTransport, SmartTCPServer_for_testing)]