[dataimport] Properly escape strings sent to COPY FROM (closes #5278743)

See http://www.postgresql.org/docs/9.1/static/sql-copy.html#AEN64296 for escaping codes.

author: Rémi Cardona <remi.cardona@logilab.fr>
changeset: efbbf1e93a04
branch: default
phase: public
hidden: no
parent revision: #52a976c5d27a [connection] provide some missing documentation bits
child revisions: #fd7dd485c745 [dataimport] Turn dataimport.py into a package., #31327bd26931 [dataimport] Turn the module into a package, #015d053f6843 [dataimport] Turn the module into a package
files modified by this revision
dataimport.py
test/unittest_dataimport.py
# HG changeset patch
# User Rémi Cardona <remi.cardona@logilab.fr>
# Date 1415370810 -3600
# Fri Nov 07 15:33:30 2014 +0100
# Node ID efbbf1e93a046ba250b50570c7c602029dd2c346
# Parent 52a976c5d27a7c75f0146f55db357cc85f5aeef3
[dataimport] Properly escape strings sent to COPY FROM (closes #5278743)

See http://www.postgresql.org/docs/9.1/static/sql-copy.html#AEN64296
for escaping codes.

diff --git a/dataimport.py b/dataimport.py
@@ -447,26 +447,16 @@
1  def _copyfrom_buffer_convert_string(value, **opts):
2      '''Convert string value.
3 
4      Recognized keywords:
5      :encoding: resulting string encoding (default: utf-8)
6 -    :replace_sep: character used when input contains characters
7 -                  that conflict with the column separator.
8      '''
9      encoding = opts.get('encoding','utf-8')
10 -    replace_sep = opts.get('replace_sep', None)
11 -    # Remove separators used in string formatting
12 -    for _char in (u'\t', u'\r', u'\n'):
13 -        if _char in value:
14 -            # If a replace_sep is given, replace
15 -            # the separator
16 -            # (and thus avoid empty buffer)
17 -            if replace_sep is None:
18 -                raise ValueError('conflicting separator: '
19 -                                 'you must provide the replace_sep option')
20 -            value = value.replace(_char, replace_sep)
21 -        value = value.replace('\\', r'\\')
22 +    escape_chars = ((u'\\', ur'\\'), (u'\t', u'\\t'), (u'\r', u'\\r'),
23 +                    (u'\n', u'\\n'))
24 +    for char, replace in escape_chars:
25 +        value = value.replace(char, replace)
26      if isinstance(value, unicode):
27          value = value.encode(encoding)
28      return value
29 
30  def _copyfrom_buffer_convert_date(value, **opts):
diff --git a/test/unittest_dataimport.py b/test/unittest_dataimport.py
@@ -47,12 +47,13 @@
31          # simple
32          self.assertEqual('babar', cnvt('babar'))
33          # unicode
34          self.assertEqual('\xc3\xa9l\xc3\xa9phant', cnvt(u'éléphant'))
35          self.assertEqual('\xe9l\xe9phant', cnvt(u'éléphant', encoding='latin1'))
36 -        self.assertEqual('babar#', cnvt('babar\t', replace_sep='#'))
37 -        self.assertRaises(ValueError, cnvt, 'babar\t')
38 +        # escaping
39 +        self.assertEqual('babar\\tceleste\\n', cnvt('babar\tceleste\n'))
40 +        self.assertEqual(r'C:\\new\tC:\\test', cnvt('C:\\new\tC:\\test'))
41 
42      def test_convert_date(self):
43          cnvt = dataimport._copyfrom_buffer_convert_date
44          self.assertEqual('0666-01-13', cnvt(DT.date(666, 1, 13)))
45