/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/merge3.py

Fix merge2 to use PatienceSequenceMatcher

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2004, 2005 by Canonical Ltd
 
2
 
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
 
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
 
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
 
 
18
# mbp: "you know that thing where cvs gives you conflict markers?"
 
19
# s: "i hate that."
 
20
 
 
21
 
 
22
from bzrlib.errors import CantReprocessAndShowBase
 
23
from bzrlib.patiencediff import PatienceSequenceMatcher
 
24
from bzrlib.textfile import check_text_lines
 
25
 
 
26
SequenceMatcher = PatienceSequenceMatcher
 
27
 
 
28
 
 
29
def intersect(ra, rb):
 
30
    """Given two ranges return the range where they intersect or None.
 
31
 
 
32
    >>> intersect((0, 10), (0, 6))
 
33
    (0, 6)
 
34
    >>> intersect((0, 10), (5, 15))
 
35
    (5, 10)
 
36
    >>> intersect((0, 10), (10, 15))
 
37
    >>> intersect((0, 9), (10, 15))
 
38
    >>> intersect((0, 9), (7, 15))
 
39
    (7, 9)
 
40
    """
 
41
    assert ra[0] <= ra[1]
 
42
    assert rb[0] <= rb[1]
 
43
    
 
44
    sa = max(ra[0], rb[0])
 
45
    sb = min(ra[1], rb[1])
 
46
    if sa < sb:
 
47
        return sa, sb
 
48
    else:
 
49
        return None
 
50
 
 
51
 
 
52
def compare_range(a, astart, aend, b, bstart, bend):
 
53
    """Compare a[astart:aend] == b[bstart:bend], without slicing.
 
54
    """
 
55
    if (aend-astart) != (bend-bstart):
 
56
        return False
 
57
    for ia, ib in zip(xrange(astart, aend), xrange(bstart, bend)):
 
58
        if a[ia] != b[ib]:
 
59
            return False
 
60
    else:
 
61
        return True
 
62
        
 
63
 
 
64
 
 
65
 
 
66
class Merge3(object):
 
67
    """3-way merge of texts.
 
68
 
 
69
    Given BASE, OTHER, THIS, tries to produce a combined text
 
70
    incorporating the changes from both BASE->OTHER and BASE->THIS.
 
71
    All three will typically be sequences of lines."""
 
72
    def __init__(self, base, a, b):
 
73
        check_text_lines(base)
 
74
        check_text_lines(a)
 
75
        check_text_lines(b)
 
76
        self.base = base
 
77
        self.a = a
 
78
        self.b = b
 
79
 
 
80
 
 
81
 
 
82
    def merge_lines(self,
 
83
                    name_a=None,
 
84
                    name_b=None,
 
85
                    name_base=None,
 
86
                    start_marker='<<<<<<<',
 
87
                    mid_marker='=======',
 
88
                    end_marker='>>>>>>>',
 
89
                    base_marker=None,
 
90
                    reprocess=False):
 
91
        """Return merge in cvs-like form.
 
92
        """
 
93
        if base_marker and reprocess:
 
94
            raise CantReprocessAndShowBase()
 
95
        if name_a:
 
96
            start_marker = start_marker + ' ' + name_a
 
97
        if name_b:
 
98
            end_marker = end_marker + ' ' + name_b
 
99
        if name_base and base_marker:
 
100
            base_marker = base_marker + ' ' + name_base
 
101
        merge_regions = self.merge_regions()
 
102
        if reprocess is True:
 
103
            merge_regions = self.reprocess_merge_regions(merge_regions)
 
104
        for t in merge_regions:
 
105
            what = t[0]
 
106
            if what == 'unchanged':
 
107
                for i in range(t[1], t[2]):
 
108
                    yield self.base[i]
 
109
            elif what == 'a' or what == 'same':
 
110
                for i in range(t[1], t[2]):
 
111
                    yield self.a[i]
 
112
            elif what == 'b':
 
113
                for i in range(t[1], t[2]):
 
114
                    yield self.b[i]
 
115
            elif what == 'conflict':
 
116
                yield start_marker + '\n'
 
117
                for i in range(t[3], t[4]):
 
118
                    yield self.a[i]
 
119
                if base_marker is not None:
 
120
                    yield base_marker + '\n'
 
121
                    for i in range(t[1], t[2]):
 
122
                        yield self.base[i]
 
123
                yield mid_marker + '\n'
 
124
                for i in range(t[5], t[6]):
 
125
                    yield self.b[i]
 
126
                yield end_marker + '\n'
 
127
            else:
 
128
                raise ValueError(what)
 
129
        
 
130
        
 
131
 
 
132
 
 
133
 
 
134
    def merge_annotated(self):
 
135
        """Return merge with conflicts, showing origin of lines.
 
136
 
 
137
        Most useful for debugging merge.        
 
138
        """
 
139
        for t in self.merge_regions():
 
140
            what = t[0]
 
141
            if what == 'unchanged':
 
142
                for i in range(t[1], t[2]):
 
143
                    yield 'u | ' + self.base[i]
 
144
            elif what == 'a' or what == 'same':
 
145
                for i in range(t[1], t[2]):
 
146
                    yield what[0] + ' | ' + self.a[i]
 
147
            elif what == 'b':
 
148
                for i in range(t[1], t[2]):
 
149
                    yield 'b | ' + self.b[i]
 
150
            elif what == 'conflict':
 
151
                yield '<<<<\n'
 
152
                for i in range(t[3], t[4]):
 
153
                    yield 'A | ' + self.a[i]
 
154
                yield '----\n'
 
155
                for i in range(t[5], t[6]):
 
156
                    yield 'B | ' + self.b[i]
 
157
                yield '>>>>\n'
 
158
            else:
 
159
                raise ValueError(what)
 
160
        
 
161
        
 
162
 
 
163
 
 
164
 
 
165
    def merge_groups(self):
 
166
        """Yield sequence of line groups.  Each one is a tuple:
 
167
 
 
168
        'unchanged', lines
 
169
             Lines unchanged from base
 
170
 
 
171
        'a', lines
 
172
             Lines taken from a
 
173
 
 
174
        'same', lines
 
175
             Lines taken from a (and equal to b)
 
176
 
 
177
        'b', lines
 
178
             Lines taken from b
 
179
 
 
180
        'conflict', base_lines, a_lines, b_lines
 
181
             Lines from base were changed to either a or b and conflict.
 
182
        """
 
183
        for t in self.merge_regions():
 
184
            what = t[0]
 
185
            if what == 'unchanged':
 
186
                yield what, self.base[t[1]:t[2]]
 
187
            elif what == 'a' or what == 'same':
 
188
                yield what, self.a[t[1]:t[2]]
 
189
            elif what == 'b':
 
190
                yield what, self.b[t[1]:t[2]]
 
191
            elif what == 'conflict':
 
192
                yield (what,
 
193
                       self.base[t[1]:t[2]],
 
194
                       self.a[t[3]:t[4]],
 
195
                       self.b[t[5]:t[6]])
 
196
            else:
 
197
                raise ValueError(what)
 
198
 
 
199
 
 
200
    def merge_regions(self):
 
201
        """Return sequences of matching and conflicting regions.
 
202
 
 
203
        This returns tuples, where the first value says what kind we
 
204
        have:
 
205
 
 
206
        'unchanged', start, end
 
207
             Take a region of base[start:end]
 
208
 
 
209
        'same', astart, aend
 
210
             b and a are different from base but give the same result
 
211
 
 
212
        'a', start, end
 
213
             Non-clashing insertion from a[start:end]
 
214
 
 
215
        Method is as follows:
 
216
 
 
217
        The two sequences align only on regions which match the base
 
218
        and both descendents.  These are found by doing a two-way diff
 
219
        of each one against the base, and then finding the
 
220
        intersections between those regions.  These "sync regions"
 
221
        are by definition unchanged in both and easily dealt with.
 
222
 
 
223
        The regions in between can be in any of three cases:
 
224
        conflicted, or changed on only one side.
 
225
        """
 
226
 
 
227
        # section a[0:ia] has been disposed of, etc
 
228
        iz = ia = ib = 0
 
229
        
 
230
        for zmatch, zend, amatch, aend, bmatch, bend in self.find_sync_regions():
 
231
            #print 'match base [%d:%d]' % (zmatch, zend)
 
232
            
 
233
            matchlen = zend - zmatch
 
234
            assert matchlen >= 0
 
235
            assert matchlen == (aend - amatch)
 
236
            assert matchlen == (bend - bmatch)
 
237
            
 
238
            len_a = amatch - ia
 
239
            len_b = bmatch - ib
 
240
            len_base = zmatch - iz
 
241
            assert len_a >= 0
 
242
            assert len_b >= 0
 
243
            assert len_base >= 0
 
244
 
 
245
            #print 'unmatched a=%d, b=%d' % (len_a, len_b)
 
246
 
 
247
            if len_a or len_b:
 
248
                # try to avoid actually slicing the lists
 
249
                equal_a = compare_range(self.a, ia, amatch,
 
250
                                        self.base, iz, zmatch)
 
251
                equal_b = compare_range(self.b, ib, bmatch,
 
252
                                        self.base, iz, zmatch)
 
253
                same = compare_range(self.a, ia, amatch,
 
254
                                     self.b, ib, bmatch)
 
255
 
 
256
                if same:
 
257
                    yield 'same', ia, amatch
 
258
                elif equal_a and not equal_b:
 
259
                    yield 'b', ib, bmatch
 
260
                elif equal_b and not equal_a:
 
261
                    yield 'a', ia, amatch
 
262
                elif not equal_a and not equal_b:
 
263
                    yield 'conflict', iz, zmatch, ia, amatch, ib, bmatch
 
264
                else:
 
265
                    raise AssertionError("can't handle a=b=base but unmatched")
 
266
 
 
267
                ia = amatch
 
268
                ib = bmatch
 
269
            iz = zmatch
 
270
 
 
271
            # if the same part of the base was deleted on both sides
 
272
            # that's OK, we can just skip it.
 
273
 
 
274
                
 
275
            if matchlen > 0:
 
276
                assert ia == amatch
 
277
                assert ib == bmatch
 
278
                assert iz == zmatch
 
279
                
 
280
                yield 'unchanged', zmatch, zend
 
281
                iz = zend
 
282
                ia = aend
 
283
                ib = bend
 
284
    
 
285
 
 
286
    def reprocess_merge_regions(self, merge_regions):
 
287
        """Where there are conflict regions, remove the agreed lines.
 
288
 
 
289
        Lines where both A and B have made the same changes are 
 
290
        eliminated.
 
291
        """
 
292
        for region in merge_regions:
 
293
            if region[0] != "conflict":
 
294
                yield region
 
295
                continue
 
296
            type, iz, zmatch, ia, amatch, ib, bmatch = region
 
297
            a_region = self.a[ia:amatch]
 
298
            b_region = self.b[ib:bmatch]
 
299
            matches = SequenceMatcher(None, a_region, 
 
300
                                      b_region).get_matching_blocks()
 
301
            next_a = ia
 
302
            next_b = ib
 
303
            for region_ia, region_ib, region_len in matches[:-1]:
 
304
                region_ia += ia
 
305
                region_ib += ib
 
306
                reg = self.mismatch_region(next_a, region_ia, next_b,
 
307
                                           region_ib)
 
308
                if reg is not None:
 
309
                    yield reg
 
310
                yield 'same', region_ia, region_len+region_ia
 
311
                next_a = region_ia + region_len
 
312
                next_b = region_ib + region_len
 
313
            reg = self.mismatch_region(next_a, amatch, next_b, bmatch)
 
314
            if reg is not None:
 
315
                yield reg
 
316
 
 
317
 
 
318
    @staticmethod
 
319
    def mismatch_region(next_a, region_ia,  next_b, region_ib):
 
320
        if next_a < region_ia or next_b < region_ib:
 
321
            return 'conflict', None, None, next_a, region_ia, next_b, region_ib
 
322
            
 
323
 
 
324
    def find_sync_regions(self):
 
325
        """Return a list of sync regions, where both descendents match the base.
 
326
 
 
327
        Generates a list of (base1, base2, a1, a2, b1, b2).  There is
 
328
        always a zero-length sync region at the end of all the files.
 
329
        """
 
330
 
 
331
        ia = ib = 0
 
332
        amatches = SequenceMatcher(None, self.base, self.a).get_matching_blocks()
 
333
        bmatches = SequenceMatcher(None, self.base, self.b).get_matching_blocks()
 
334
        len_a = len(amatches)
 
335
        len_b = len(bmatches)
 
336
 
 
337
        sl = []
 
338
 
 
339
        while ia < len_a and ib < len_b:
 
340
            abase, amatch, alen = amatches[ia]
 
341
            bbase, bmatch, blen = bmatches[ib]
 
342
 
 
343
            # there is an unconflicted block at i; how long does it
 
344
            # extend?  until whichever one ends earlier.
 
345
            i = intersect((abase, abase+alen), (bbase, bbase+blen))
 
346
            if i:
 
347
                intbase = i[0]
 
348
                intend = i[1]
 
349
                intlen = intend - intbase
 
350
 
 
351
                # found a match of base[i[0], i[1]]; this may be less than
 
352
                # the region that matches in either one
 
353
                assert intlen <= alen
 
354
                assert intlen <= blen
 
355
                assert abase <= intbase
 
356
                assert bbase <= intbase
 
357
 
 
358
                asub = amatch + (intbase - abase)
 
359
                bsub = bmatch + (intbase - bbase)
 
360
                aend = asub + intlen
 
361
                bend = bsub + intlen
 
362
 
 
363
                assert self.base[intbase:intend] == self.a[asub:aend], \
 
364
                       (self.base[intbase:intend], self.a[asub:aend])
 
365
 
 
366
                assert self.base[intbase:intend] == self.b[bsub:bend]
 
367
 
 
368
                sl.append((intbase, intend,
 
369
                           asub, aend,
 
370
                           bsub, bend))
 
371
 
 
372
            # advance whichever one ends first in the base text
 
373
            if (abase + alen) < (bbase + blen):
 
374
                ia += 1
 
375
            else:
 
376
                ib += 1
 
377
            
 
378
        intbase = len(self.base)
 
379
        abase = len(self.a)
 
380
        bbase = len(self.b)
 
381
        sl.append((intbase, intbase, abase, abase, bbase, bbase))
 
382
 
 
383
        return sl
 
384
 
 
385
 
 
386
 
 
387
    def find_unconflicted(self):
 
388
        """Return a list of ranges in base that are not conflicted."""
 
389
        am = SequenceMatcher(None, self.base, self.a).get_matching_blocks()
 
390
        bm = SequenceMatcher(None, self.base, self.b).get_matching_blocks()
 
391
 
 
392
        unc = []
 
393
 
 
394
        while am and bm:
 
395
            # there is an unconflicted block at i; how long does it
 
396
            # extend?  until whichever one ends earlier.
 
397
            a1 = am[0][0]
 
398
            a2 = a1 + am[0][2]
 
399
            b1 = bm[0][0]
 
400
            b2 = b1 + bm[0][2]
 
401
            i = intersect((a1, a2), (b1, b2))
 
402
            if i:
 
403
                unc.append(i)
 
404
 
 
405
            if a2 < b2:
 
406
                del am[0]
 
407
            else:
 
408
                del bm[0]
 
409
                
 
410
        return unc
 
411
 
 
412
 
 
413
def main(argv):
 
414
    # as for diff3 and meld the syntax is "MINE BASE OTHER"
 
415
    a = file(argv[1], 'rt').readlines()
 
416
    base = file(argv[2], 'rt').readlines()
 
417
    b = file(argv[3], 'rt').readlines()
 
418
 
 
419
    m3 = Merge3(base, a, b)
 
420
 
 
421
    #for sr in m3.find_sync_regions():
 
422
    #    print sr
 
423
 
 
424
    # sys.stdout.writelines(m3.merge_lines(name_a=argv[1], name_b=argv[3]))
 
425
    sys.stdout.writelines(m3.merge_annotated())
 
426
 
 
427
 
 
428
if __name__ == '__main__':
 
429
    import sys
 
430
    sys.exit(main(sys.argv))