// copy.go

package gokb

import (
	"database/sql/driver"
	"encoding/binary"
	"errors"
	"fmt"
	"strings"
	"sync"
)
  9. var (
  10. errCopyInClosed = errors.New("kb: copyin statement has already been closed")
  11. errBinaryCopyNotSupported = errors.New("kb: only text format supported for COPY")
  12. errCopyToNotSupported = errors.New("kb: COPY TO is not supported")
  13. errCopyNotSupportedOutsideTxn = errors.New("kb: COPY is only allowed inside a transaction")
  14. errCopyInProgress = errors.New("kb: COPY in progress")
  15. )
  16. // CopyIn creates a COPY FROM statement which can be prepared with
  17. // Tx.Prepare(). The target table should be visible in search_path.
  18. func CopyIn(table string, columns ...string) string {
  19. stmt := "COPY " + QuoteIdentifier(table) + " ("
  20. for i, col := range columns {
  21. if i != 0 {
  22. stmt += ", "
  23. }
  24. stmt += QuoteIdentifier(col)
  25. }
  26. stmt += ") FROM STDIN"
  27. return stmt
  28. }
  29. // CopyInSchema creates a COPY FROM statement which can be prepared with
  30. // Tx.Prepare().
  31. func CopyInSchema(schema, table string, columns ...string) string {
  32. stmt := "COPY " + QuoteIdentifier(schema) + "." + QuoteIdentifier(table) + " ("
  33. for i, col := range columns {
  34. if i != 0 {
  35. stmt += ", "
  36. }
  37. stmt += QuoteIdentifier(col)
  38. }
  39. stmt += ") FROM STDIN"
  40. return stmt
  41. }
  42. type copyin struct {
  43. cn *conn
  44. buffer []byte
  45. rowData chan []byte
  46. done chan bool
  47. closed bool
  48. sync.Mutex // guards err
  49. err error
  50. }
  51. const ciBufferSize = 64 * 1024
  52. // flush buffer before the buffer is filled up and needs reallocation
  53. const ciBufferFlushSize = 63 * 1024
  54. func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) {
  55. if !cn.isInTransaction() {
  56. return nil, errCopyNotSupportedOutsideTxn
  57. }
  58. ci := &copyin{
  59. cn: cn,
  60. buffer: make([]byte, 0, ciBufferSize),
  61. rowData: make(chan []byte),
  62. done: make(chan bool, 1),
  63. }
  64. // add CopyData identifier + 4 bytes for message length
  65. ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0)
  66. b := cn.writeBuf('Q')
  67. b.string(q)
  68. cn.send(b)
  69. awaitCopyInResponse:
  70. for {
  71. t, r := cn.recv1()
  72. switch t {
  73. case 'G':
  74. if r.byte() != 0 {
  75. err = errBinaryCopyNotSupported
  76. break awaitCopyInResponse
  77. }
  78. go ci.resploop()
  79. return ci, nil
  80. case 'H':
  81. err = errCopyToNotSupported
  82. break awaitCopyInResponse
  83. case 'E':
  84. err = parseError(r)
  85. case 'Z':
  86. if err == nil {
  87. ci.setBad()
  88. errorf("unexpected ReadyForQuery in response to COPY")
  89. }
  90. cn.processReadyForQuery(r)
  91. return nil, err
  92. default:
  93. ci.setBad()
  94. errorf("unknown response for copy query: %q", t)
  95. }
  96. }
  97. // something went wrong, abort COPY before we return
  98. b = cn.writeBuf('f')
  99. b.string(err.Error())
  100. cn.send(b)
  101. for {
  102. t, r := cn.recv1()
  103. switch t {
  104. case 'c', 'C', 'E':
  105. case 'Z':
  106. // correctly aborted, we're done
  107. cn.processReadyForQuery(r)
  108. return nil, err
  109. default:
  110. ci.setBad()
  111. errorf("unknown response for CopyFail: %q", t)
  112. }
  113. }
  114. }
  115. func (ci *copyin) flush(buf []byte) {
  116. // set message length (without message identifier)
  117. binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1))
  118. _, err := ci.cn.c.Write(buf)
  119. if err != nil {
  120. panic(err)
  121. }
  122. }
  123. func (ci *copyin) resploop() {
  124. for {
  125. var r readBuf
  126. t, err := ci.cn.recvMessage(&r)
  127. if err != nil {
  128. ci.setBad()
  129. ci.setError(err)
  130. ci.done <- true
  131. return
  132. }
  133. switch t {
  134. case 'C':
  135. // complete
  136. case 'N':
  137. if n := ci.cn.noticeHandler; n != nil {
  138. n(parseError(&r))
  139. }
  140. case 'Z':
  141. ci.cn.processReadyForQuery(&r)
  142. ci.done <- true
  143. return
  144. case 'E':
  145. err := parseError(&r)
  146. ci.setError(err)
  147. default:
  148. ci.setBad()
  149. ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t))
  150. ci.done <- true
  151. return
  152. }
  153. }
  154. }
  155. func (ci *copyin) setBad() {
  156. ci.Lock()
  157. ci.cn.bad = true
  158. ci.Unlock()
  159. }
  160. func (ci *copyin) isBad() bool {
  161. ci.Lock()
  162. b := ci.cn.bad
  163. ci.Unlock()
  164. return b
  165. }
  166. func (ci *copyin) isErrorSet() bool {
  167. ci.Lock()
  168. isSet := (ci.err != nil)
  169. ci.Unlock()
  170. return isSet
  171. }
  172. // setError() sets ci.err if one has not been set already. Caller must not be
  173. // holding ci.Mutex.
  174. func (ci *copyin) setError(err error) {
  175. ci.Lock()
  176. if ci.err == nil {
  177. ci.err = err
  178. }
  179. ci.Unlock()
  180. }
  181. func (ci *copyin) NumInput() int {
  182. return -1
  183. }
  184. func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
  185. return nil, ErrNotSupported
  186. }
  187. // Exec inserts values into the COPY stream. The insert is asynchronous
  188. // and Exec can return errors from previous Exec calls to the same
  189. // COPY stmt.
  190. //
  191. // You need to call Exec(nil) to sync the COPY stream and to get any
  192. // errors from pending data, since Stmt.Close() doesn't return errors
  193. // to the user.
  194. func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
  195. if ci.closed {
  196. return nil, errCopyInClosed
  197. }
  198. if ci.isBad() {
  199. return nil, driver.ErrBadConn
  200. }
  201. defer ci.cn.errRecover(&err)
  202. if ci.isErrorSet() {
  203. return nil, ci.err
  204. }
  205. if len(v) == 0 {
  206. return driver.RowsAffected(0), ci.Close()
  207. }
  208. numValues := len(v)
  209. for i, value := range v {
  210. ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value)
  211. if i < numValues-1 {
  212. ci.buffer = append(ci.buffer, '\t')
  213. }
  214. }
  215. ci.buffer = append(ci.buffer, '\n')
  216. if len(ci.buffer) > ciBufferFlushSize {
  217. ci.flush(ci.buffer)
  218. // reset buffer, keep bytes for message identifier and length
  219. ci.buffer = ci.buffer[:5]
  220. }
  221. return driver.RowsAffected(0), nil
  222. }
  223. func (ci *copyin) Close() (err error) {
  224. if ci.closed { // Don't do anything, we're already closed
  225. return nil
  226. }
  227. ci.closed = true
  228. if ci.isBad() {
  229. return driver.ErrBadConn
  230. }
  231. defer ci.cn.errRecover(&err)
  232. if len(ci.buffer) > 0 {
  233. ci.flush(ci.buffer)
  234. }
  235. // Avoid touching the scratch buffer as resploop could be using it.
  236. err = ci.cn.sendSimpleMessage('c')
  237. if err != nil {
  238. return err
  239. }
  240. <-ci.done
  241. ci.cn.inCopy = false
  242. if ci.isErrorSet() {
  243. err = ci.err
  244. return err
  245. }
  246. return nil
  247. }