Scribunto: Synchronize latest of ustring.lua [#609]

pull/620/head
gnosygnu 5 years ago
parent af9d6c3a92
commit bdb1945d4f

@ -76,14 +76,14 @@ end
-- @param s string utf8-encoded string to decode
-- @return table
local function utf8_explode( s )
-- xowa:PERF:if s equals previous_string, return previous_table; handles loops such as:
-- for idx = 1, mw.ustring.len( from ) do
-- if (cp == mw.ustring.codepoint( from, idx)) then
if (s == utf8_explode__previous_string) then
return utf8_explode__previous_table;
end
-- xowa:PERF:if s equals previous_string, return previous_table; handles loops such as:
-- for idx = 1, mw.ustring.len( from ) do
-- if (cp == mw.ustring.codepoint( from, idx)) then
if (s == utf8_explode__previous_string) then
return utf8_explode__previous_table;
end
local ret = {
len = 0,
codepoints = {},
@ -157,8 +157,8 @@ local function utf8_explode( s )
ret.bytepos[#ret.bytepos + 1] = l + 1
ret.bytepos[#ret.bytepos + 1] = l + 1
utf8_explode__previous_string = s;
utf8_explode__previous_table = ret;
utf8_explode__previous_string = s;
utf8_explode__previous_table = ret;
return ret
end
@ -280,9 +280,32 @@ function ustring.gcodepoint( s, i, j )
checkString( 'gcodepoint', s )
checkType( 'gcodepoint', 2, i, 'number', true )
checkType( 'gcodepoint', 3, j, 'number', true )
local cp = { ustring.codepoint( s, i or 1, j or -1 ) }
local cps = utf8_explode( s )
if cps == nil then
error( "bad argument #1 for 'gcodepoint' (string is not UTF-8)", 2 )
end
i = i or 1
if i < 0 then
i = cps.len + i + 1
end
j = j or -1
if j < 0 then
j = cps.len + j + 1
end
if j < i then
return function ()
return nil
end
end
i = math.max( 1, math.min( i, cps.len + 1 ) )
j = math.max( 1, math.min( j, cps.len + 1 ) )
return function ()
return table.remove( cp, 1 )
if i <= j then
local ret = cps.codepoints[i]
i = i + 1
return ret
end
return nil
end
end
@ -570,7 +593,14 @@ local function find( s, cps, rawpat, pattern, init, noAnchor )
-- Returns the position after the set and a table holding the matching characters
parse_charset = function ( pp )
local _, ep
local epp = pattern.bytepos[pp]
local epp = pattern.bytepos[pp] + 1
if S.sub( rawpat, epp, epp ) == '^' then
epp = epp + 1
end
if S.sub( rawpat, epp, epp ) == ']' then
-- Lua's string module effectively does this
epp = epp + 1
end
repeat
_, ep = S.find( rawpat, ']', epp, true )
if not ep then
@ -593,9 +623,13 @@ local function find( s, cps, rawpat, pattern, init, noAnchor )
invert = true
pp = pp + 1
end
local first = true
while true do
local c = pattern.codepoints[pp]
if c == 0x25 then -- '%'
if not first and c == 0x5d then -- closing ']'
pp = pp + 1
break
elseif c == 0x25 then -- '%'
c = pattern.codepoints[pp + 1]
if charsets[c] then
csrefs[#csrefs + 1] = charsets[c]
@ -608,15 +642,13 @@ local function find( s, cps, rawpat, pattern, init, noAnchor )
cs[i] = 1
end
pp = pp + 3
elseif c == 0x5d then -- closing ']'
pp = pp + 1
break
elseif not c then -- Should never get here, but Just In Case...
error( 'Missing close-bracket', 3 )
else
cs[c] = 1
pp = pp + 1
end
first = false
end
local ret
@ -733,7 +765,7 @@ local function find( s, cps, rawpat, pattern, init, noAnchor )
local ep = match( sp, pp )
if ep then
for i = 1, ncapt do
captures[i] = getcapt( i, 'Unclosed capture beginning at pattern character ' .. captparen[pp], 2 )
captures[i] = getcapt( i, 'Unclosed capture beginning at pattern character ' .. captparen[i], 2 )
end
return sp, ep - 1, unpack( captures )
end
@ -748,22 +780,28 @@ end
-- inside brackets and aren't followed by quantifiers and aren't part of a
-- '%b', but that's too complicated to check.
-- * If it contains a negated character set.
-- * If it contains "%a" or any of the other %-prefixed character sets except
-- %z or %Z.
-- * If it contains a '.' not followed by '*', '+', or '-'. A bare '.' or '.?'
-- would try to match a partial UTF-8 character, but the others will happily
-- enough match a whole character thinking it's 2 or 4.
-- * If it contains "%a" or any of the other %-prefixed character sets except %z.
-- * If it contains a '.' not followed by '*', '+', '-'. A bare '.' or '.?'
-- matches a partial UTF-8 character, but the others will happily enough
-- match a whole UTF-8 character thinking it's 2, 3 or 4.
-- * If it contains position-captures.
-- * If it matches the empty string
--
-- @param string pattern
-- @return boolean
local function patternIsSimple( pattern )
local findWithPcall = function ( ... )
local ok, ret = pcall( S.find, ... )
return ok and ret
end
return not (
S.find( pattern, '[\128-\255]' ) or
S.find( pattern, '%[%^' ) or
S.find( pattern, '%%[acdlpsuwxACDLPSUWX]' ) or
S.find( pattern, '%.[^*+-]' ) or
S.find( pattern, '()', 1, true )
S.find( pattern, '%%[acdlpsuwxACDLPSUWXZ]' ) or
S.find( pattern, '%.[^*+-]' ) or S.find( pattern, '%.$' ) or
S.find( pattern, '()', 1, true ) or
pattern == '' or findWithPcall( '', pattern )
)
end
@ -808,12 +846,19 @@ function ustring.find( s, pattern, init, plain )
if init and init > cps.len + 1 then
init = cps.len + 1
end
local m = { S.find( s, pattern, cps.bytepos[init], plain ) }
local m
if plain then
m = { true, S.find( s, pattern, cps.bytepos[init], plain ) }
else
m = { pcall( S.find, s, pattern, cps.bytepos[init], plain ) }
end
if m[1] then
m[1] = cpoffset( cps, m[1] )
if m[2] then
m[2] = cpoffset( cps, m[2] )
m[3] = cpoffset( cps, m[3] )
end
return unpack( m, 2 )
end
return unpack( m )
end
return find( s, cps, pattern, pat, init )
@ -841,7 +886,10 @@ function ustring.match( s, pattern, init )
end
if patternIsSimple( pattern ) then
return S.match( s, pattern, cps.bytepos[init] )
local ret = { pcall( S.match, s, pattern, cps.bytepos[init] ) }
if ret[1] then
return unpack( ret, 2 )
end
end
local m = { find( s, cps, pattern, pat, init ) }
@ -867,7 +915,10 @@ function ustring.gmatch( s, pattern )
checkString( 'gmatch', s )
checkPattern( 'gmatch', pattern )
if patternIsSimple( pattern ) then
return S.gmatch( s, pattern )
local ret = { pcall( S.gmatch, s, pattern ) }
if ret[1] then
return unpack( ret, 2 )
end
end
local cps = utf8_explode( s )
@ -908,7 +959,18 @@ function ustring.gsub( s, pattern, repl, n )
checkPattern( 'gsub', pattern )
checkType( 'gsub', 4, n, 'number', true )
if patternIsSimple( pattern ) then
return S.gsub( s, pattern, repl, n )
local ret = { pcall( S.gsub, s, pattern, repl, n ) }
if ret[1] then
return unpack( ret, 2 )
end
end
if n == nil then
n = 1e100
end
if n < 1 then
-- No replacement
return s, 0
end
local cps = utf8_explode( s )
@ -919,9 +981,6 @@ function ustring.gsub( s, pattern, repl, n )
if pat == nil then
error( "bad argument #2 for 'gsub' (string is not UTF-8)", 2 )
end
if n == nil then
n = 1e100
end
if pat.codepoints[1] == 0x5e then -- '^': Pattern is anchored
-- There can be only the one match, so make that explicit
@ -945,8 +1004,9 @@ function ustring.gsub( s, pattern, repl, n )
local init = 1
local ct = 0
local ret = {}
while init < cps.len and ct < n do
local m = { find( s, cps, pattern, pat, init ) }
local zeroAdjustment = 0
repeat
local m = { find( s, cps, pattern, pat, init + zeroAdjustment ) }
if not m[1] then
break
end
@ -954,15 +1014,20 @@ function ustring.gsub( s, pattern, repl, n )
ret[#ret + 1] = sub( s, cps, init, m[1] - 1 )
end
local mm = sub( s, cps, m[1], m[2] )
local val
-- This simplifies the code for the function and table cases (tp == 1 and tp == 2) when there are
-- no captures in the pattern. As documented it would be incorrect for the string case by making
-- %1 act like %0 instead of raising an "invalid capture index" error, but Lua in fact does
-- exactly that for string.gsub.
if #m < 3 then
m[3] = mm
end
local val, valType
if tp == 1 then
if m[3] then
val = repl( unpack( m, 3 ) )
else
val = repl( mm )
end
val = repl( unpack( m, 3 ) )
elseif tp == 2 then
val = repl[m[3] or mm]
val = repl[m[3]]
elseif tp == 3 then
if ct == 0 and #m < 11 then
local ss = S.gsub( repl, '%%[%%0-' .. ( #m - 2 ) .. ']', 'x' )
@ -986,10 +1051,15 @@ function ustring.gsub( s, pattern, repl, n )
}
val = S.gsub( repl, '%%[%%0-9]', t )
end
valType = type( val )
if valType ~= 'nil' and valType ~= 'string' and valType ~= 'number' then
error( 'invalid replacement value (a ' .. valType .. ')', 2 )
end
ret[#ret + 1] = val or mm
init = m[2] + 1
ct = ct + 1
end
zeroAdjustment = m[2] < m[1] and 1 or 0
until init > cps.len or ct >= n
if init <= cps.len then
ret[#ret + 1] = sub( s, cps, init, cps.len )
end
@ -999,14 +1069,14 @@ end
---- Unicode Normalization ----
-- These functions load a conversion table when called
local function internalToNFD( cps )
local function internalDecompose( cps, decomp )
local cp = {}
local normal = require 'ustring/normalization-data'
-- Decompose into cp, using the lookup table and logic for hangul
for i = 1, cps.len do
local c = cps.codepoints[i]
local m = normal.decomp[c]
local m = decomp[c]
if m then
for j = 0, #m do
cp[#cp + 1] = m[j]
@ -1036,45 +1106,11 @@ local function internalToNFD( cps )
return cp, 1, l
end
-- Normalize a string to NFC
--
-- Based on MediaWiki's UtfNormal class. Returns nil if the string is not valid
-- UTF-8.
--
-- @param s string
-- @return string|nil
function ustring.toNFC( s )
checkString( 'toNFC', s )
-- ASCII is always NFC
if not S.find( s, '[\128-\255]' ) then
return s
end
local cps = utf8_explode( s )
if cps == nil then
return nil
end
local function internalCompose( cp, _, l )
local normal = require 'ustring/normalization-data'
-- First, scan through to see if the string is definitely already NFC
local ok = true
for i = 1, cps.len do
local c = cps.codepoints[i]
if normal.check[c] then
ok = false
break
end
end
if ok then
return s
end
-- Next, expand to NFD
local cp, _, l = internalToNFD( cps )
-- Then combine to NFC. Since NFD->NFC can never expand a character
-- sequence, we can do this in-place.
-- Since NFD->NFC can never expand a character sequence, we can do this
-- in-place.
local comp = normal.comp[cp[1]]
local sc = 1
local j = 1
@ -1110,7 +1146,45 @@ function ustring.toNFC( s )
end
end
return internalChar( cp, 1, j )
return cp, 1, j
end
-- Normalize a string to NFC
--
-- Based on MediaWiki's UtfNormal class. Returns nil if the string is not valid
-- UTF-8.
--
-- @param s string
-- @return string|nil
function ustring.toNFC( s )
checkString( 'toNFC', s )
-- ASCII is always NFC
if not S.find( s, '[\128-\255]' ) then
return s
end
local cps = utf8_explode( s )
if cps == nil then
return nil
end
local normal = require 'ustring/normalization-data'
-- First, scan through to see if the string is definitely already NFC
local ok = true
for i = 1, cps.len do
local c = cps.codepoints[i]
if normal.check[c] then
ok = false
break
end
end
if ok then
return s
end
-- Next, expand to NFD then recompose
return internalChar( internalCompose( internalDecompose( cps, normal.decomp ) ) )
end
-- Normalize a string to NFD
@ -1123,7 +1197,31 @@ end
function ustring.toNFD( s )
checkString( 'toNFD', s )
-- ASCII is always NFC
-- ASCII is always NFD
if not S.find( s, '[\128-\255]' ) then
return s
end
local cps = utf8_explode( s )
if cps == nil then
return nil
end
local normal = require 'ustring/normalization-data'
return internalChar( internalDecompose( cps, normal.decomp ) )
end
-- Normalize a string to NFKC
--
-- Based on MediaWiki's UtfNormal class. Returns nil if the string is not valid
-- UTF-8.
--
-- @param s string
-- @return string|nil
function ustring.toNFKC( s )
checkString( 'toNFKC', s )
-- ASCII is always NFKC
if not S.find( s, '[\128-\255]' ) then
return s
end
@ -1132,8 +1230,34 @@ function ustring.toNFD( s )
if cps == nil then
return nil
end
local normal = require 'ustring/normalization-data'
-- Next, expand to NFKD then recompose
return internalChar( internalCompose( internalDecompose( cps, normal.decompK ) ) )
end
-- Normalize a string to NFKD
--
-- Based on MediaWiki's UtfNormal class. Returns nil if the string is not valid
-- UTF-8.
--
-- @param s string
-- @return string|nil
function ustring.toNFKD( s )
checkString( 'toNFKD', s )
return internalChar( internalToNFD( cps ) )
-- ASCII is always NFKD
if not S.find( s, '[\128-\255]' ) then
return s
end
local cps = utf8_explode( s )
if cps == nil then
return nil
end
local normal = require 'ustring/normalization-data'
return internalChar( internalDecompose( cps, normal.decompK ) )
end
return ustring

Loading…
Cancel
Save