library(data.table)
## Simulation 1: compare several cross-validation schemes for selecting
## the number of segments in simulated normal data whose true number of
## segments is n.segments.
n.segments <- 10
seg.mean.vec <- 1:n.segments
data.per.segment <- 10
data.mean.vec <- rep(seg.mean.vec, each=data.per.segment)
n.data <- length(data.mean.vec)
n.validation.sets <- 100
n.folds.vec <- c(10, 2)
prop.valid.vec <- 1/n.folds.vec
sim.result.list <- list()
if(interactive()){
  ## Each K-fold scheme is repeated n.validation.sets/K times, so K must
  ## divide n.validation.sets for every scheme to yield the same number
  ## of validation sets.
  stopifnot(n.validation.sets %% n.folds.vec == 0)
  ## The validation masks depend only on the seeds set below, not on the
  ## simulated data, so they are built once instead of once per data set.
  is.valid.vec.list <- list()
  ## K-fold schemes: one logical mask per (seed, held-out fold) pair.
  for(n.folds in n.folds.vec){
    uniq.folds <- seq_len(n.folds)
    n.seeds <- n.validation.sets/n.folds
    split.type <- sprintf("%d-fold %d times", n.folds, n.seeds)
    for(seed in seq_len(n.seeds)){
      set.seed(seed)
      fold.vec <- sample(rep(uniq.folds, length.out=n.data))
      for(valid.fold in uniq.folds){
        is.valid.vec.list[[split.type]][[paste(seed, valid.fold)]] <-
          fold.vec==valid.fold
      }
    }
  }
  ## Random-proportion schemes: one mask per split, with the same
  ## validation proportions as the K-fold schemes above.
  for(prop.valid in prop.valid.vec){
    split.type <- sprintf("%d%% %d times", 100*prop.valid, n.validation.sets)
    prop.vec <- c(subtrain=1-prop.valid, validation=prop.valid)
    for(split.i in seq_len(n.validation.sets)){
      set.seed(split.i)
      is.valid.vec.list[[split.type]][[split.i]] <- binsegRcpp::random_set_vec(
        n.data, prop.vec) == "validation"
    }
  }
  for(data.seed in 1:100){
    set.seed(data.seed)
    data.vec <- rnorm(n.data, data.mean.vec, 0.1)
    ## Validation loss for every model size, for each split of each scheme.
    ## Masks are looked up by position, so the K-fold entries (stored by
    ## name) are visited in insertion order.
    loss.dt <- CJ(
      split.i=seq_len(n.validation.sets),
      type=names(is.valid.vec.list)
    )[, {
      is.valid <- is.valid.vec.list[[type]][[split.i]]
      bs.model <- binsegRcpp::binseg_normal(
        data.vec, is.validation.vec=is.valid)
      bs.model$splits[, data.table(
        segments,
        validation.loss)]
    }, by=.(split.i, type)]
    loss.stats <- loss.dt[, .(
      mean.valid.loss=mean(validation.loss)
    ), by=.(type, segments)]
    ## Model size with minimum validation loss in each split.
    select.each.split <- loss.dt[
    , .SD[which.min(validation.loss)],
      by=.(type, split.i)]
    selected.times <- select.each.split[, .(
      times=.N
    ), by=.(type, segments)]
    ## Three ways of combining the splits into one selected model size.
    selected.segments <- rbind(
      select.each.split[, .(
        selected=min(segments)
      ), by=.(method=paste(type, "min err, min segs"))],
      selected.times[, .(
        selected=segments[which.max(times)]
      ), by=.(method=paste(type, "min err, max times"))],
      loss.stats[, .(
        selected=segments[which.min(mean.valid.loss)]
      ), by=.(method=paste(type, "mean err, min err"))]
    )
    sim.result.list[[data.seed]] <- data.table(
      data.seed, selected.segments, n.segments)
  }
  sim.result <- rbindlist(sim.result.list)
  ## Error of each method over all simulated data sets, best first.
  sim.err <- sim.result[, .(
    zero.one.loss=sum(selected != n.segments),
    L1.loss=sum(abs(selected-n.segments)),
    L2.loss=sum((selected-n.segments)^2)
  ), by=method][order(zero.one.loss)]
  ## Explicit print: inside braces only the last value is auto-printed,
  ## and the last expression here is the plot below.
  print(sim.err)
  plot(data.vec)
}
The code above compares several types of cross-validation for selecting the number of segments in simulated random normal data. The table above shows various error rates, which compare the selected number of segments to the true number of segments in the simulation. The best methods appear to be the ones which use min err, max times.
## Simulation 2: how many random validation splits are needed before the
## most frequently selected number of segments stabilizes, for several
## validation set proportions.
n.segments <- 20
seg.mean.vec <- 1:n.segments
data.per.segment <- 5
data.mean.vec <- rep(seg.mean.vec, each=data.per.segment)
n.data <- length(data.mean.vec)
n.validation.sets <- 200
prop.valid <- c(0.01, 0.05, 0.1, 0.25, 0.5)
if(interactive()){
  sim.result <- data.table(data.seed=1:100)[, {
    set.seed(data.seed)
    data.vec <- rnorm(n.data, data.mean.vec, 0.1)
    ## For each split and proportion, the model size with minimum
    ## validation loss. Inside j, prop.valid is the scalar group value.
    select.each.split <- CJ(
      split.i=seq_len(n.validation.sets),
      prop.valid=prop.valid
    )[, {
      set.seed(split.i)
      prop.sets <- c(subtrain=1-prop.valid, validation=prop.valid)
      is.valid <- binsegRcpp::random_set_vec(
        n.data, prop.sets)=="validation"
      bs.model <- binsegRcpp::binseg_normal(
        data.vec, is.validation.vec=is.valid)
      bs.model$splits[, .(selected=segments[which.min(validation.loss)])]
    }, by=.(split.i, prop.valid)]
    ## Using only the first n.splits splits, the size selected most often
    ## for each proportion.
    data.table(n.splits=seq_len(n.validation.sets))[, {
      select.each.split[split.i <= n.splits, .(
        times=.N
      ),
      by=.(prop.valid, selected)
      ][, .SD[which.max(times), .(selected)], by=prop.valid]
    }, by=n.splits]
  }, by=data.seed]
  if(require(ggplot2)){
    ## Explicit print: this plot is not the last value of the enclosing
    ## block, so it would otherwise never be drawn.
    print(ggplot()+
      scale_color_gradient(low="red", high="black")+
      geom_line(aes(
        n.splits, selected,
        group=paste(data.seed, prop.valid),
        color=prop.valid),
        data=sim.result)+
      scale_y_continuous(breaks=seq(0, 100, by=10)))
  }
  ## Number of simulated data sets for which the true number of segments
  ## was selected.
  accuracy.dt <- sim.result[, .(
    correct=sum(selected==n.segments)
  ), by=.(prop.valid, n.splits)]
  if(require(ggplot2)){
    ## NOTE(review): ggplot2 >= 3.4.0 deprecates size= for lines in favour
    ## of linewidth=; kept as size= for compatibility with older versions.
    gg <- ggplot()+
      geom_line(aes(
        n.splits, correct,
        group=prop.valid,
        color=prop.valid),
        size=2,
        data=accuracy.dt)+
      scale_y_continuous("number of correctly chosen data sets")+
      scale_color_gradient(low="red", high="black")
    ## Printed explicitly so that display does not depend on this being
    ## the last value of the enclosing block.
    print(if(require(directlabels)){
      direct.label(gg, "right.polygons")
    }else{
      gg
    })
  }
}
The plot above suggests that 100 validation sets are sufficient.
Need a high-speed mirror for your open-source project?
Contact our mirror admin team at info@clientvps.com.
This archive is provided as a free public service to the community.
Proudly supported by infrastructure from VPSPulse , RxServers , BuyNumber , UnitVPS , OffshoreName and secure payment technology by ArionPay.