mirror of
https://gh.wpcy.net/https://github.com/discourse/discourse.git
synced 2026-05-23 23:04:48 +08:00
This replaces the existing implementation for usernames and group names with a generic base class and multiple implementations for users, groups, categories, and badges. This fixes problems with the old implementation, such as endless loops and names exceeding the maximum length.
153 lines
4.2 KiB
Ruby
Vendored
153 lines
4.2 KiB
Ruby
Vendored
# frozen_string_literal: true
|
|
|
|
module Migrations::Importer
  # Wrapper around a dedicated PG connection to the Discourse database, used by
  # the importer for bulk COPY inserts, ID/sequence bookkeeping and streaming
  # queries.
  class DiscourseDB
    # Returned by #query_result: `rows` is an Enumerator over the result rows
    # (bare values for single-column results, arrays otherwise) and
    # `column_count` is the number of columns in the result set.
    QueryResult = Data.define(:rows, :column_count)

    # Number of rows sent per COPY statement; each batch gets its own transaction.
    COPY_BATCH_SIZE = 1_000

    # When a row hash passed to #copy_data has a truthy value for this key, the
    # row is not inserted but reported back to the caller as skipped.
    SKIP_ROW_MARKER = :"$skip"

    # Opens a new PG connection using the ActiveRecord database configuration.
    # Result values are decoded into Ruby types via PG::BasicTypeMapForResults.
    def initialize
      @encoder = PG::TextEncoder::CopyRow.new
      @connection = PG::Connection.new(database_configuration)
      @connection.type_map_for_results = PG::BasicTypeMapForResults.new(@connection)
    end

    # Bulk-inserts rows into a table using `COPY ... FROM STDIN`.
    #
    # Rows are processed in slices of COPY_BATCH_SIZE. Each slice is copied
    # inside its own transaction, and the block (if given) is called inside
    # that same transaction after the COPY has finished. The transaction is
    # committed when the block returns, so an exception raised by the block
    # rolls the batch back.
    #
    # @param table_name [String, Symbol] target table; interpolated into the
    #   SQL as-is (not quoted), so it must be a trusted identifier
    # @param column_names [Array<Symbol, String>] columns to fill, in order;
    #   each name is also the key used to read the value from a row hash
    # @param rows [Enumerable<Hash>] rows to insert, keyed by column name
    # @yieldparam inserted_rows [Array<Hash>] rows copied in this batch
    # @yieldparam skipped_rows [Array<Hash>] rows marked with SKIP_ROW_MARKER
    # @return [nil]
    # @raise [PG::Error] when the COPY fails; a single bad row aborts the import
    def copy_data(table_name, column_names, rows)
      quoted_column_name_list = column_names.map { |c| quote_identifier(c) }.join(",")
      sql = "COPY #{table_name} (#{quoted_column_name_list}) FROM STDIN"

      inserted_rows = []
      skipped_rows = []
      column_count = column_names.size
      # One buffer reused for every row; put_copy_data encodes it right away,
      # so overwriting it for the next row is safe.
      data = Array.new(column_count)

      rows.each_slice(COPY_BATCH_SIZE) do |sliced_rows|
        # TODO: Maybe add error handling and check if all rows fail to insert, or only
        # some of them fail. Currently, if a single row fails to insert, then an exception
        # will stop the whole import. Which seems fine because ideally the import script
        # should ensure all data is valid. We might need to see how this works out in
        # actual migrations...
        @connection.transaction do
          @connection.copy_data(sql, @encoder) do
            sliced_rows.each do |row|
              if row[SKIP_ROW_MARKER]
                skipped_rows << row
                next
              end

              # Plain index loop instead of each_with_index, presumably to keep
              # per-row overhead low on this hot path.
              i = 0
              while i < column_count
                data[i] = row[column_names[i]]
                i += 1
              end

              @connection.put_copy_data(data)
              inserted_rows << row
            end
          end

          # Give the caller a chance to do some work once a batch has been
          # copied (e.g. to store ID mappings). This runs before the commit, so
          # if it raises, the batch is rolled back.
          yield inserted_rows, skipped_rows if block_given?

          inserted_rows.clear
          skipped_rows.clear
        end
      end

      nil
    end

    # Returns the highest positive `id` in the given table, or 0 when the table
    # has no rows with a positive id.
    #
    # @param table_name [String, Symbol]
    # @return [Integer]
    def last_id_of(table_name)
      query = <<~SQL
        SELECT COALESCE(MAX(id), 0)
          FROM #{quote_identifier(table_name)}
         WHERE id > 0
      SQL

      # Block form clears the PG::Result and returns the block's value.
      @connection.exec(query) { |result| result.getvalue(0, 0) }
    end

    # Sets the table's `id` sequence to the current MAX(id), so that later
    # inserts relying on the sequence continue after the imported IDs.
    # Does nothing when the table has no positive ids.
    #
    # @param table_name [String, Symbol]
    # @return [nil]
    def fix_last_id_of(table_name)
      quoted_table_name = quote_identifier(table_name)
      query = <<~SQL
        SELECT SETVAL(PG_GET_SERIAL_SEQUENCE($1, 'id'), MAX(id))
          FROM #{quoted_table_name}
        HAVING MAX(id) > 0
      SQL

      # The quoted name goes in as a bind parameter rather than being
      # interpolated into a '...' literal: quote_ident does not escape single
      # quotes, so such a name would otherwise break out of the literal.
      @connection.exec_params(query, [quoted_table_name]).clear
      nil
    end

    # Returns the column names of a table in the current schema, ordered by
    # their position in the table.
    #
    # Restricting to current_schema() avoids picking up columns of a
    # same-named table that lives in another schema.
    #
    # @param table_name [String, Symbol] unqualified table name
    # @return [Array<Symbol>]
    def column_names(table_name)
      query = <<~SQL
        SELECT column_name
          FROM information_schema.columns
         WHERE table_schema = CURRENT_SCHEMA()
           AND table_name = $1
         ORDER BY ordinal_position
      SQL

      @connection.exec_params(query, [table_name]) do |result|
        result.column_values(0).map(&:to_sym)
      end
    end

    # Runs a query and returns its rows as a lazy Enumerator.
    # See #query_result for the streaming caveats.
    def query_array(sql, *params)
      query_result(sql, *params).rows
    end

    # Sends a parameterized query and streams its rows in single-row mode.
    #
    # Rows are only fetched from the connection while the returned enumerator
    # is iterated. Consume it completely, and only once, before using the
    # connection for anything else; otherwise results are still pending on the
    # connection.
    #
    # @param sql [String] query with $1, $2, ... placeholders
    # @param params [Array] bind parameter values
    # @return [QueryResult]
    def query_result(sql, *params)
      @connection.send_query_params(sql, params)
      @connection.set_single_row_mode

      first_result = @connection.get_result
      return QueryResult.new(rows: Enumerator.new {}, column_count: 0) unless first_result

      column_count = first_result.nfields
      single_column = column_count == 1

      rows_enumerator =
        Enumerator.new do |y|
          result = first_result
          while result
            # Single-column results yield bare values instead of 1-element arrays.
            result.stream_each_row { |row| y << (single_column ? row[0] : row) }
            result.clear
            result = @connection.get_result
          end
        end

      QueryResult.new(rows: rows_enumerator, column_count:)
    end

    # Closes the connection. Calling it again after the connection has been
    # closed is a no-op.
    def close
      @connection.finish unless @connection.finished?
    end

    private

    # Builds PG connection parameters from the ActiveRecord configuration.
    # In the test environment, PGUSER / PGPASSWORD are used as fallbacks when
    # the configuration has no username / password. Nil entries are dropped.
    def database_configuration
      db_config = ActiveRecord::Base.connection_db_config.configuration_hash

      # credentials for PostgreSQL in CI environment
      if Rails.env.test?
        username = ENV["PGUSER"]
        password = ENV["PGPASSWORD"]
      end

      {
        host: db_config[:host],
        port: db_config[:port],
        user: db_config[:username] || username,
        password: db_config[:password] || password,
        dbname: db_config[:database],
      }.compact
    end

    # Quotes a table/column name as a PostgreSQL identifier.
    def quote_identifier(identifier)
      PG::Connection.quote_ident(identifier.to_s)
    end
  end
end
|