diff --git a/go/logic/applier.go b/go/logic/applier.go index 9e336e2f5..4cb63ddb6 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -801,9 +801,15 @@ func (this *Applier) readMigrationMinValues(tx *gosql.Tx, uniqueKey *sql.UniqueK return err } } + if err := rows.Err(); err != nil { + return err + } this.migrationContext.Log.Infof("Migration min values: [%s]", this.migrationContext.MigrationRangeMinValues) + if this.migrationContext.MigrationRangeMinValues != nil { + this.migrationContext.MigrationRangeMinValues.NormalizeValues(uniqueKey.Columns) + } - return rows.Err() + return nil } // readMigrationMaxValues returns the maximum values to be iterated on rowcopy @@ -826,9 +832,15 @@ func (this *Applier) readMigrationMaxValues(tx *gosql.Tx, uniqueKey *sql.UniqueK return err } } + if err := rows.Err(); err != nil { + return err + } this.migrationContext.Log.Infof("Migration max values: [%s]", this.migrationContext.MigrationRangeMaxValues) + if this.migrationContext.MigrationRangeMaxValues != nil { + this.migrationContext.MigrationRangeMaxValues.NormalizeValues(uniqueKey.Columns) + } - return rows.Err() + return nil } // ReadMigrationRangeValues reads min/max values that will be used for rowcopy. @@ -911,6 +923,7 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo } if hasFurtherRange { this.migrationContext.MigrationIterationRangeMaxValues = iterationRangeMaxValues + this.migrationContext.MigrationIterationRangeMaxValues.NormalizeValues(this.migrationContext.UniqueKey.Columns) return hasFurtherRange, nil } } diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 97895890d..da8bc8fcc 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -748,6 +748,9 @@ func (this *Inspector) applyColumnTypes(databaseName, tableName string, columnsL column.Type = sql.BinaryColumnType column.BinaryOctetLength = columnOctetLength } + if strings.HasPrefix(columnType, "bit") { + column.Type = sql.BitColumnType + } if strings.Contains(extra, " GENERATED") { column.IsVirtual = true } diff --git a/go/sql/builder.go b/go/sql/builder.go index 940ca4ca3..f9fc30627 100644 --- a/go/sql/builder.go +++ b/go/sql/builder.go @@ -58,6 +58,8 @@ func buildColumnsPreparedValues(columns *ColumnList) []string { token = fmt.Sprintf("ELT(?, %s)", column.EnumValues) } else if column.Type == JSONColumnType { token = "convert(? using utf8mb4)" + } else if column.Type == BitColumnType { + token = "cast(? as unsigned)" } else { token = "?" } @@ -340,6 +342,7 @@ func BuildUniqueKeyRangeEndPreparedQueryViaOffset(databaseName, tableName string if includeRangeStartValues { startRangeComparisonSign = GreaterThanOrEqualsComparisonSign } + rangeStartComparison, rangeExplodedArgs, err := BuildRangePreparedComparison(uniqueKeyColumns, rangeStartArgs, startRangeComparisonSign) if err != nil { return "", explodedArgs, err diff --git a/go/sql/types.go b/go/sql/types.go index 1a8f8a2e2..b93fe553c 100644 --- a/go/sql/types.go +++ b/go/sql/types.go @@ -24,6 +24,7 @@ const ( JSONColumnType FloatColumnType BinaryColumnType + BitColumnType ) const maxMediumintUnsigned int32 = 16777215 @@ -95,6 +96,15 @@ func (this *Column) convertArg(arg interface{}) interface{} { } } + // We convert BIT col to uint64 to force correct value comparison. + if this.Type == BitColumnType { + var n uint64 + for _, b := range arg2Bytes { + n = (n << 8) | uint64(b) + } + arg = n + } + return arg } @@ -342,6 +352,12 @@ func (this *ColumnValues) AbstractValues() []interface{} { return this.abstractValues } +func (this *ColumnValues) NormalizeValues(columns ColumnList) { + for i, col := range columns.Columns() { + this.abstractValues[i] = col.convertArg(this.abstractValues[i]) + } +} + func (this *ColumnValues) StringColumn(index int) string { val := this.AbstractValues()[index] if ints, ok := val.([]uint8); ok { diff --git a/go/sql/types_test.go b/go/sql/types_test.go index 9275bbb85..cec2f2a74 100644 --- a/go/sql/types_test.go +++ b/go/sql/types_test.go @@ -156,3 +156,13 @@ func TestConvertArgBinaryColumnNoPaddingWhenFull(t *testing.T) { require.Equal(t, 20, len(resultBytes)) require.Equal(t, fullValue, resultBytes) } + +func TestConvertArgBitColumn(t *testing.T) { + b := []uint8{0x00, 0x00, 0xa3} + col := Column{ + Name: "bit_col", + Type: BitColumnType, + } + result := col.convertArg(b) + require.Equal(t, uint64(163), result) +} diff --git a/localtests/bit-unique-key/create.sql b/localtests/bit-unique-key/create.sql new file mode 100644 index 000000000..0497ae429 --- /dev/null +++ b/localtests/bit-unique-key/create.sql @@ -0,0 +1,9 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + `id` bigint not null, + `bit_col` bit not null, + primary key (`id`, `bit_col`) +) DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin; +insert into gh_ost_test values (1, b'1'); +insert into gh_ost_test values (2, b'1'); +insert into gh_ost_test values (3, b'1');