Forecasting El Niño-Southern Oscillation (ENSO)

This time, we use the convLSTM introduced in a previous post to forecast El Niño-Southern Oscillation (ENSO).

ENSO refers to a changing pattern of sea surface temperatures and sea-level pressures occurring in the equatorial Pacific. Of its three overall states, probably the best-known is El Niño. El Niño occurs when surface water temperatures in the eastern Pacific are higher than normal, and the strong winds that normally blow from east to west are unusually weak. The opposite conditions are termed La Niña. Everything in-between is classified as normal.

ENSO has a great impact on the weather worldwide, and routinely harms ecosystems and societies through storms, droughts and flooding, possibly resulting in famines and economic crises. The best societies can do is try to adapt and mitigate severe consequences. Such efforts are aided by accurate forecasts, the further ahead the better.

Here, deep learning (DL) can potentially help: Variables like sea surface temperatures and pressures are given on a spatial grid – that of the earth – and as we know, DL is good at extracting spatial (e.g., image) features. For ENSO prediction, architectures like convolutional neural networks (Ham, Kim, and Luo (2019a)) or convolutional-recurrent hybrids are habitually used. One such hybrid is just our convLSTM; it operates on sequences of features given on a spatial grid. Today, then, we'll be training a model for ENSO forecasting. This model will have a convLSTM for its central ingredient.

Before we start, a note. While our model fits in well with architectures described in the relevant papers, the same can't be said for the amount of training data used. For reasons of practicality, we use actual observations only; consequently, we end up with a small (relative to the task) dataset. In contrast, research papers tend to make use of climate simulations, resulting in significantly more data to work with.

From the outset, then, we don't expect stellar performance. Nonetheless, this should make for an interesting case study, and a useful code template for our readers to apply to their own data.

We will attempt to predict monthly average sea surface temperature in the Niño 3.4 region, as represented by the Niño 3.4 Index, plus categorization as one of El Niño, La Niña or neutral. Predictions will be based on prior monthly sea surface temperatures spanning a large portion of the globe.

On the input side, public and ready-to-use data may be downloaded from Tokyo Climate Center; as to prediction targets, we obtain index and classification here.

Input and target data both are provided monthly. They intersect in the time period ranging from 1891-01-01 to 2020-08-01; so this is the range of dates we'll be zooming in on.

Input: Sea Surface Temperatures

Monthly sea surface temperatures are provided in a latitude-longitude grid of resolution 1°. Details of how the data were processed are available here.

Data files are available in GRIB format; each file contains averages computed for a single month. We can either download individual files or generate a text file of URLs for download. In case you'd like to follow along with the post, you'll find the contents of the text file I generated in the appendix. Once you've saved those URLs to a file, you can have R get the files for you like so:
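The download code itself was lost in formatting; a minimal sketch, assuming the URLs sit in a text file (here named `files`, one URL per line) and that a target directory `grb` already exists:

```r
library(purrr)

grb_dir <- "grb"

# download every file listed in "files" into grb_dir,
# keeping the original file name
purrr::walk(
  readLines("files", warn = FALSE),
  function(url) download.file(url, file.path(grb_dir, basename(url)))
)
```

The file and directory names are placeholders; adjust them to wherever you saved the URL list.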

From R, we can read GRIB files using stars. For example:
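The call producing the output below did not survive formatting; it might have looked like this (the file name is illustrative, standing for whichever single-month file you pick):

```r
library(stars)

# read a single month's grid; grb_dir and the file name are assumptions
read_stars(file.path(grb_dir, "sst_example_month.grb"))
```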

stars object with 2 dimensions and 1 attribute
attribute(s):
 Min.   :-274.9  
 1st Qu.:-272.8  
 Median :-259.1  
 Mean   :-260.0  
 3rd Qu.:-248.4  
 Max.   :-242.8  
 NA's   :21001   
dimension(s):
  from  to offset delta                       refsys point values    
x    1 360      0     1 Coordinate System importe...    NA   NULL [x]
y    1 180     90    -1 Coordinate System importe...    NA   NULL [y]

So in this GRIB file, we have one attribute – which we know to be sea surface temperature – on a two-dimensional grid. As to the latter, we can complement what stars tells us with additional information found in the documentation:

The east-west grid points run eastward from 0.5ºE to 0.5ºW, while the north-south grid points run northward from 89.5ºS to 89.5ºN.

We note a few things we'll want to do with this data. For one, the temperatures seem to be given in Kelvin, but with minus signs. We'll remove the minus signs and convert to degrees Celsius for convenience. We'll also have to think about what to do with the NAs that appear for all non-maritime coordinates.

Before we get there though, we need to combine data from all files into a single data frame. This adds an additional dimension, time, ranging from 1891/01/01 to 2020/12/01:

grb <- read_stars(
  file.path(grb_dir, map(readLines("files", warn = FALSE), basename)), along = "time") %>%
  st_set_dimensions(3,
                    values = seq(as.Date("1891-01-01"), as.Date("2020-12-01"), by = "months"),
                    names = "time"
  )

stars object with 3 dimensions and 1 attribute
attribute(s), summary of first 1e+05 cells:
 Min.   :-274.9  
 1st Qu.:-273.3  
 Median :-258.8  
 Mean   :-260.0  
 3rd Qu.:-247.8  
 Max.   :-242.8  
 NA's   :33724   
dimension(s):
     from   to offset delta                       refsys point                    values    
x       1  360      0     1 Coordinate System importe...    NA                      NULL [x]
y       1  180     90    -1 Coordinate System importe...    NA                      NULL [y]
time    1 1560     NA    NA                         Date    NA 1891-01-01,...,2020-12-01    

Let's visually inspect the spatial distribution of monthly temperatures for one year, 2020:

ggplot() +
  geom_stars(data = grb %>% filter(between(time, as.Date("2020-01-01"), as.Date("2020-12-01"))), alpha = 0.8) +
  facet_wrap("time") +
  scale_fill_viridis() +
  coord_equal() +
  theme_map() +
  theme(legend.position = "none") 


Figure 1: Monthly sea surface temperatures, 2020/01/01 – 2020/12/01.

Target: Niño 3.4 Index

For the Niño 3.4 Index, we download the monthly data and, among the provided features, zoom in on two: the index itself (column NINO34_MEAN) and PHASE, which can be E (El Niño), L (La Niña) or N (neutral).

nino <- read_table2("ONI_NINO34_1854-2020.txt", skip = 9) %>%
  mutate(month = as.Date(paste0(YEAR, "-", `MON/MMM`, "-01"))) %>%
  select(month, NINO34_MEAN, PHASE) %>%
  filter(between(month, as.Date("1891-01-01"), as.Date("2020-08-01"))) %>%
  mutate(phase_code = as.numeric(as.factor(PHASE)))


Next, we look at how to get the data into a format convenient for training and prediction.


First, we remove all input data for points in time where ground truth data are still missing.

Next, as is done by e.g. Ham, Kim, and Luo (2019b), we only use grid points between 55° south and 60° north. This has the additional advantage of reducing memory requirements.

sst <- grb %>% filter(between(y, -55, 60))

360, 115, 1560

As already alluded to, with the little data we have we can't expect much in terms of generalization. Still, we set aside a small portion of the data for validation, since we'd like for this post to serve as a useful template to be used with bigger datasets.
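The split itself is not shown; one way to do it, consistent with the validation range used below (the thirty-one years from 1990 on), would be to filter on the time dimension:

```r
# everything before 1990 goes to training,
# the years from 1990 on to validation
sst_train <- sst %>% filter(time < as.Date("1990-01-01"))
sst_valid <- sst %>% filter(time >= as.Date("1990-01-01"))
```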

From here on, we work with R arrays.

sst_train <- as.tbl_cube.stars(sst_train)$mets[[1]]
sst_valid <- as.tbl_cube.stars(sst_valid)$mets[[1]]

Conversion to degrees Celsius is not strictly necessary, as preliminary experiments showed a slight performance increase due to normalizing the input, and we're going to do that anyway. Still, it reads nicer to humans than Kelvin.

sst_train <- sst_train + 273.15
quantile(sst_train, na.rm = TRUE)
     0%     25%     50%     75%    100% 
-1.8000 12.9975 21.8775 26.8200 34.3700 

Not at all surprisingly, global warming is evident from inspecting temperature distribution on the validation set (which was chosen to span the last thirty-one years).

sst_valid <- sst_valid + 273.15
quantile(sst_valid, na.rm = TRUE)
    0%    25%    50%    75%   100% 
-1.800 13.425 22.335 27.240 34.870 

The next-to-last step normalizes both sets according to training mean and variance.

train_mean <- mean(sst_train, na.rm = TRUE)
train_sd <- sd(sst_train, na.rm = TRUE)

sst_train <- (sst_train - train_mean) / train_sd

sst_valid <- (sst_valid - train_mean) / train_sd

Finally, what should we do about the NA entries? We set them to zero, the (training set) mean. That may not be enough of an action though: It means we're feeding the network roughly 30% misleading data. This is something we're not done with yet.

sst_train[is.na(sst_train)] <- 0
sst_valid[is.na(sst_valid)] <- 0


The target data are split analogously. Let's check though: Are phases (categorizations) distributed equally in both sets?

nino_train <- nino %>% filter(month < as.Date("1990-01-01"))
nino_valid <- nino %>% filter(month >= as.Date("1990-01-01"))

nino_train %>% group_by(phase_code, PHASE) %>% summarise(count = n(), avg = mean(NINO34_MEAN))
# A tibble: 3 x 4
# Groups:   phase_code [3]
  phase_code PHASE count   avg
       <dbl> <chr> <int> <dbl>
1          1 E       301  27.7
2          2 L       333  25.6
3          3 N       554  26.7
nino_valid %>% group_by(phase_code, PHASE) %>% summarise(count = n(), avg = mean(NINO34_MEAN))
# A tibble: 3 x 4
# Groups:   phase_code [3]
  phase_code PHASE count   avg
       <dbl> <chr> <int> <dbl>
1          1 E        93  28.1
2          2 L        93  25.9
3          3 N       182  27.2

This doesn't look too bad. Of course, we again see the overall rise in temperature, irrespective of phase.

Finally, we normalize the index, the same way we did for the input data.

train_mean_nino <- mean(nino_train$NINO34_MEAN)
train_sd_nino <- sd(nino_train$NINO34_MEAN)

nino_train <- nino_train %>% mutate(NINO34_MEAN = scale(NINO34_MEAN, center = train_mean_nino, scale = train_sd_nino))
nino_valid <- nino_valid %>% mutate(NINO34_MEAN = scale(NINO34_MEAN, center = train_mean_nino, scale = train_sd_nino))

On to the torch dataset.

The dataset is responsible for correctly matching up inputs and targets.

Our goal is to take six months of global sea surface temperatures and predict the Niño 3.4 Index for the following month. Input-wise, the model will expect the following format semantics:

batch_size * timesteps * channels * height * width, where

  • batch_size is the number of observations worked on in one round of computations,

  • timesteps chains consecutive observations from adjacent months,

  • channels corresponds to available visual channels in the “image,” and

  • height and width together constitute the spatial grid.

In .getitem(), we select the consecutive observations, starting at a given index, and stack them in dimension one. (One, not two, as batches will only start to exist once the dataloader comes into play.)

Now, what about the target? Our ultimate goal was – is – predicting the Niño 3.4 Index. However, as you see, we define three targets: One is the index, as expected; an additional one holds the spatially-gridded sea surface temperatures for the prediction month. Why? Our main instrument, the most prominent constituent of the model, will be a convLSTM, an architecture designed for spatial prediction. Thus, to train it efficiently, we want to give it the opportunity to predict values on a spatial grid. So far so good; but there is one more target, the phase/category. This was added for experimentation purposes: Maybe predicting both index and phase helps in training?

Finally, here is the code for the dataset. In our experiments, we based predictions on inputs from the preceding six months (n_timesteps <- 6). This is a parameter you might want to play with, though.

n_timesteps <- 6

enso_dataset <- dataset(
  name = "enso_dataset",
  initialize = function(sst, nino, n_timesteps) {
    self$sst <- sst
    self$nino <- nino
    self$n_timesteps <- n_timesteps
  },
  .getitem = function(i) {
    x <- torch_tensor(self$sst[, , i:(n_timesteps + i - 1)]) # (360, 115, n_timesteps)
    x <- x$permute(c(3, 1, 2))$unsqueeze(2) # (n_timesteps, 1, 360, 115)
    y1 <- torch_tensor(self$sst[, , n_timesteps + i])$unsqueeze(1) # (1, 360, 115)
    y2 <- torch_tensor(self$nino$NINO34_MEAN[n_timesteps + i])
    y3 <- torch_tensor(self$nino$phase_code[n_timesteps + i])$squeeze()$to(torch_long())
    list(x = x, y1 = y1, y2 = y2, y3 = y3)
  },
  .length = function() {
    nrow(self$nino) - n_timesteps
  }
)

train_ds <- enso_dataset(sst_train, nino_train, n_timesteps)
valid_ds <- enso_dataset(sst_valid, nino_valid, n_timesteps)

After the custom dataset, we create the – quite typical – dataloaders, making use of a batch size of 4.

batch_size <- 4

train_dl <- train_ds %>% dataloader(batch_size = batch_size, shuffle = TRUE)

valid_dl <- valid_ds %>% dataloader(batch_size = batch_size)
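To verify the format semantics described above, we can peek at a single batch. This check is not part of the original walkthrough, but follows directly from the dataset definition:

```r
# pull one batch from the training dataloader and inspect tensor shapes
b <- train_dl %>% dataloader_make_iter() %>% dataloader_next()

dim(b$x)  # expect batch_size, timesteps, channels, and the spatial grid: 4 6 1 360 115
dim(b$y1) # the spatial target: 4 1 360 115
```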

Next, we proceed to model creation.

The model's main ingredient is the convLSTM introduced in a prior post. For convenience, we reproduce the code in the appendix.

Besides the convLSTM, the model makes use of three convolutional layers, a batchnorm layer and five linear layers. The logic is the following.

First, the convLSTM's job is to predict the next month's sea surface temperatures on the spatial grid. For that, we almost just return its final state – almost: We use self$conv1 to reduce the number of channels to 1.

For predicting index and phase, we then need to flatten the grid, as we require a single value each. This is where the additional conv layers come in. We do hope they'll aid in learning, but we also want to reduce the number of parameters a bit, downsizing the grid (strides = 2 and strides = 3, resp.) a bit before the upcoming torch_flatten().

Once we have a flat structure, learning is shared between the tasks of index and phase prediction (self$linear), until finally their paths split (self$cont and self$cat, resp.), and they return their separate outputs.

(The batchnorm? I'll comment on that in the Discussion.)

model <- nn_module(
  initialize = function(channels_in,
                        convlstm_hidden,
                        convlstm_kernel,
                        convlstm_layers) {
    self$n_layers <- convlstm_layers
    self$convlstm <- convlstm(
      input_dim = channels_in,
      hidden_dims = convlstm_hidden,
      kernel_sizes = convlstm_kernel,
      n_layers = convlstm_layers
    )
    self$conv1 <- nn_conv2d(
      in_channels = 32,
      out_channels = 1,
      kernel_size = 5,
      padding = 2
    )
    self$conv2 <- nn_conv2d(
      in_channels = 32,
      out_channels = 32,
      kernel_size = 5,
      stride = 2
    )
    self$conv3 <- nn_conv2d(
      in_channels = 32,
      out_channels = 32,
      kernel_size = 5,
      stride = 3
    )
    self$linear <- nn_linear(33408, 64)
    self$b1 <- nn_batch_norm1d(num_features = 64)
    self$cont <- nn_linear(64, 128)
    self$cat <- nn_linear(64, 128)
    self$cont_output <- nn_linear(128, 1)
    self$cat_output <- nn_linear(128, 3)
  },
  forward = function(x) {
    ret <- self$convlstm(x)
    layer_last_states <- ret[[2]]
    last_hidden <- layer_last_states[[self$n_layers]][[1]]

    next_sst <- last_hidden %>% self$conv1()

    c2 <- last_hidden %>% self$conv2()
    c3 <- c2 %>% self$conv3()

    flat <- torch_flatten(c3, start_dim = 2)
    common <- self$linear(flat) %>% self$b1() %>% nnf_relu()

    next_temp <- common %>% self$cont() %>% nnf_relu() %>% self$cont_output()
    next_nino <- common %>% self$cat() %>% nnf_relu() %>% self$cat_output()

    list(next_sst, next_temp, next_nino)
  }
)

Next, we instantiate a rather small-ish model. You're more than welcome to experiment with larger models, but training time as well as GPU memory requirements will increase.

net <- model(
  channels_in = 1,
  convlstm_hidden = c(16, 16, 32),
  convlstm_kernel = c(3, 3, 5),
  convlstm_layers = 3
)

device <- torch_device(if (cuda_is_available()) "cuda" else "cpu")

net <- net$to(device = device)
net
An `nn_module` containing 2,389,605 parameters.

── Modules ───────────────────────────────────────────────────────────────────────────────
● convlstm: <nn_module> #182,080 parameters
● conv1: <nn_conv2d> #801 parameters
● conv2: <nn_conv2d> #25,632 parameters
● conv3: <nn_conv2d> #25,632 parameters
● linear: <nn_linear> #2,138,176 parameters
● b1: <nn_batch_norm1d> #128 parameters
● cont: <nn_linear> #8,320 parameters
● cat: <nn_linear> #8,320 parameters
● cont_output: <nn_linear> #129 parameters
● cat_output: <nn_linear> #387 parameters

We have three model outputs. How should we combine the losses?

Given that the main goal is predicting the index, and the other two outputs are essentially means to an end, I found the following combination rather effective:

# weight for sea surface temperature prediction
lw_sst <- 0.2

# weight for prediction of El Niño 3.4 Index
lw_temp <- 0.4

# weight for phase prediction
lw_nino <- 0.4

The training process follows the pattern seen in all torch posts so far: For each epoch, loop over the training set, backpropagate, check performance on the validation set.

But, when we did the pre-processing, we were aware of an imminent problem: the missing temperatures for continental areas, which we set to zero. As a sole measure, this approach is clearly insufficient. What if we had chosen to use latitude-dependent averages? Or interpolation? Both may be better than a global average, but both have their problems as well. Let's at least alleviate negative consequences by not using the respective pixels for spatial loss calculation. This is taken care of by the following line below:

sst_loss <- nnf_mse_loss(sst_output[sst_target != 0], sst_target[sst_target != 0])

Here, then, is the complete training code.

optimizer <- optim_adam(net$parameters, lr = 0.001)

num_epochs <- 50

train_batch <- function(b) {
  optimizer$zero_grad()
  output <- net(b$x$to(device = device))

  sst_output <- output[[1]]
  sst_target <- b$y1$to(device = device)

  sst_loss <- nnf_mse_loss(sst_output[sst_target != 0], sst_target[sst_target != 0])
  temp_loss <- nnf_mse_loss(output[[2]], b$y2$to(device = device))
  nino_loss <- nnf_cross_entropy(output[[3]], b$y3$to(device = device))

  loss <- lw_sst * sst_loss + lw_temp * temp_loss + lw_nino * nino_loss
  loss$backward()
  optimizer$step()

  list(sst_loss$item(), temp_loss$item(), nino_loss$item(), loss$item())
}

valid_batch <- function(b) {
  output <- net(b$x$to(device = device))

  sst_output <- output[[1]]
  sst_target <- b$y1$to(device = device)

  sst_loss <- nnf_mse_loss(sst_output[sst_target != 0], sst_target[sst_target != 0])
  temp_loss <- nnf_mse_loss(output[[2]], b$y2$to(device = device))
  nino_loss <- nnf_cross_entropy(output[[3]], b$y3$to(device = device))

  loss <- lw_sst * sst_loss + lw_temp * temp_loss + lw_nino * nino_loss

  list(sst_loss$item(), temp_loss$item(), nino_loss$item(), loss$item())
}

for (epoch in 1:num_epochs) {

  net$train()

  train_loss_sst <- c()
  train_loss_temp <- c()
  train_loss_nino <- c()
  train_loss <- c()

  coro::loop(for (b in train_dl) {
    losses <- train_batch(b)
    train_loss_sst <- c(train_loss_sst, losses[[1]])
    train_loss_temp <- c(train_loss_temp, losses[[2]])
    train_loss_nino <- c(train_loss_nino, losses[[3]])
    train_loss <- c(train_loss, losses[[4]])
  })

  cat(sprintf(
    "\nEpoch %d, training: loss: %3.3f sst: %3.3f temp: %3.3f nino: %3.3f \n",
    epoch, mean(train_loss), mean(train_loss_sst), mean(train_loss_temp), mean(train_loss_nino)
  ))

  net$eval()

  valid_loss_sst <- c()
  valid_loss_temp <- c()
  valid_loss_nino <- c()
  valid_loss <- c()

  coro::loop(for (b in valid_dl) {
    losses <- valid_batch(b)
    valid_loss_sst <- c(valid_loss_sst, losses[[1]])
    valid_loss_temp <- c(valid_loss_temp, losses[[2]])
    valid_loss_nino <- c(valid_loss_nino, losses[[3]])
    valid_loss <- c(valid_loss, losses[[4]])
  })

  cat(sprintf(
    "\nEpoch %d, validation: loss: %3.3f sst: %3.3f temp: %3.3f nino: %3.3f \n",
    epoch, mean(valid_loss), mean(valid_loss_sst), mean(valid_loss_temp), mean(valid_loss_nino)
  ))

  torch_save(net, paste0(
    "model_", epoch, "_", round(mean(train_loss), 3), "_", round(mean(valid_loss), 3), ".pt"
  ))
}
When I ran this, performance on the training set decreased in a not-too-fast, but continuous way, while validation set performance kept fluctuating. For reference, total (composite) losses looked like this:

Epoch     Training    Validation
   10        0.336         0.633
   20        0.233         0.295
   30        0.135         0.461
   40        0.099         0.903
   50        0.061         0.727

Thinking of the size of the validation set – thirty-one years, or equivalently, 372 data points – those fluctuations may not be all too surprising.

Now, losses tend to be abstract; let's see what actually gets predicted. We obtain predictions for index values and phases like so …


pred_index <- c()
pred_phase <- c()

coro::loop(for (b in valid_dl) {

  output <- net(b$x$to(device = device))

  pred_index <- c(pred_index, as.numeric(output[[2]]$to(device = "cpu")))
  pred_phase <- rbind(pred_phase, as.matrix(output[[3]]$to(device = "cpu")))

})


… and combine these with the ground truth, stripping off the first six rows (six was the number of timesteps used as predictors):

valid_perf <- data.frame(
  actual_temp = nino_valid$NINO34_MEAN[(n_timesteps + 1):nrow(nino_valid)] * train_sd_nino + train_mean_nino,
  actual_nino = factor(nino_valid$phase_code[(n_timesteps + 1):nrow(nino_valid)]),
  pred_temp = pred_index * train_sd_nino + train_mean_nino,
  pred_nino = factor(pred_phase %>% apply(1, which.max))
)

For the phase, we can generate a confusion matrix:

yardstick::conf_mat(valid_perf, actual_nino, pred_nino)
          Truth
Prediction   1   2   3
         1  70   0  43
         2   0  47  10
         3  23  46 123

This looks better than expected (based on the losses). Phases 1 and 2 correspond to El Niño and La Niña, respectively, and these get sharply separated.
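If we want a single summary number, overall classification accuracy is straightforward to compute from the same data frame:

```r
# fraction of validation months where predicted and actual phase agree;
# from the confusion matrix above, this is (70 + 47 + 123) / 362, about 0.66
mean(valid_perf$actual_nino == valid_perf$pred_nino)
```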

What about the Niño 3.4 Index? Let's plot predictions versus ground truth:

valid_perf <- valid_perf %>% 
  select(actual = actual_temp, predicted = pred_temp) %>% 
  add_column(month = seq(as.Date("1990-07-01"), as.Date("2020-08-01"), by = "months")) %>%
  pivot_longer(-month, names_to = "Index", values_to = "temperature")

ggplot(valid_perf, aes(x = month, y = temperature, color = Index)) +
  geom_line() +
  scale_color_manual(values = c("#006D6F", "#B2FFFF"))


Figure 2: Niño 3.4 Index: Ground truth vs. predictions (validation set).

This doesn't look bad either. However, we need to keep in mind that we're predicting just a single time step ahead. We probably should not overestimate the results. Which leads directly to the discussion.

When working with small amounts of data, a lot can be learned by quick-ish experimentation. However, when at the same time the task is complex, one should be cautious extrapolating.

For example, well-established regularizers such as batchnorm and dropout, while intended to improve generalization to the validation set, may turn out to severely impede training itself. This is the story behind the single batchnorm layer I kept (I did try having more), and it is also why there is no dropout.

One lesson to learn from this experience, then, is: Make sure the amount of data matches the complexity of the task. This is what we see in the ENSO prediction papers published on arxiv.

If we should treat the results with caution, why even publish the post?

For one, it shows an application of convLSTM to real-world data, employing a rather complex architecture and illustrating techniques like custom losses and loss masking. Similar architectures and strategies should be applicable to a range of real-world tasks – basically, whenever predictors in a time-series problem are given on a spatial grid.

Secondly, the application itself – forecasting an atmospheric phenomenon that greatly affects ecosystems as well as human well-being – seems like a good use of deep learning. Applications like these stand out as all the more worthwhile as the same can't be said of everything deep learning is – and will be, barring effective regulation – used for.

Thanks for reading!

A1: List of GRB files

To be put into a text file for use with purrr::walk( … download.file … ).

A2: convlstm code

For an in-depth explanation of convlstm, see the blog post.


convlstm_cell <- nn_module(
  initialize = function(input_dim, hidden_dim, kernel_size, bias) {
    self$hidden_dim <- hidden_dim

    padding <- kernel_size %/% 2

    self$conv <- nn_conv2d(
      in_channels = input_dim + self$hidden_dim,
      # for each of input, forget, output, and cell gates
      out_channels = 4 * self$hidden_dim,
      kernel_size = kernel_size,
      padding = padding,
      bias = bias
    )
  },
  forward = function(x, prev_states) {

    h_prev <- prev_states[[1]]
    c_prev <- prev_states[[2]]

    combined <- torch_cat(list(x, h_prev), dim = 2)  # concatenate along channel axis
    combined_conv <- self$conv(combined)

    gate_convs <- torch_split(combined_conv, self$hidden_dim, dim = 2)
    cc_i <- gate_convs[[1]]
    cc_f <- gate_convs[[2]]
    cc_o <- gate_convs[[3]]
    cc_g <- gate_convs[[4]]

    # input, forget, output, and cell gates (corresponding to torch's LSTM)
    i <- torch_sigmoid(cc_i)
    f <- torch_sigmoid(cc_f)
    o <- torch_sigmoid(cc_o)
    g <- torch_tanh(cc_g)

    # cell state
    c_next <- f * c_prev + i * g
    # hidden state
    h_next <- o * torch_tanh(c_next)

    list(h_next, c_next)
  },
  init_hidden = function(batch_size, height, width) {
    list(torch_zeros(batch_size, self$hidden_dim, height, width, device = self$conv$weight$device),
         torch_zeros(batch_size, self$hidden_dim, height, width, device = self$conv$weight$device))
  }
)

convlstm <- nn_module(
  initialize = function(input_dim, hidden_dims, kernel_sizes, n_layers, bias = TRUE) {

    self$n_layers <- n_layers
    self$cell_list <- nn_module_list()

    for (i in 1:n_layers) {
      cur_input_dim <- if (i == 1) input_dim else hidden_dims[i - 1]
      self$cell_list$append(convlstm_cell(cur_input_dim, hidden_dims[i], kernel_sizes[i], bias))
    }
  },
  # we always assume batch-first
  forward = function(x) {

    batch_size <- x$size()[1]
    seq_len <- x$size()[2]
    height <- x$size()[4]
    width <- x$size()[5]

    # initialize hidden states
    init_hidden <- vector(mode = "list", length = self$n_layers)
    for (i in 1:self$n_layers) {
      init_hidden[[i]] <- self$cell_list[[i]]$init_hidden(batch_size, height, width)
    }

    # list containing the outputs, of length seq_len, for each layer
    # this is the same as h, at each step in the sequence
    layer_output_list <- vector(mode = "list", length = self$n_layers)

    # list containing the last states (h, c) for each layer
    layer_state_list <- vector(mode = "list", length = self$n_layers)

    cur_layer_input <- x
    hidden_states <- init_hidden

    # loop over layers
    for (i in 1:self$n_layers) {

      # every layer's hidden state starts from 0 (non-stateful)
      h_c <- hidden_states[[i]]
      h <- h_c[[1]]
      c <- h_c[[2]]

      # outputs, of length seq_len, for this layer
      # equivalently, list of h states for each time step
      output_sequence <- vector(mode = "list", length = seq_len)

      # loop over timesteps
      for (t in 1:seq_len) {
        h_c <- self$cell_list[[i]](cur_layer_input[ , t, , , ], list(h, c))
        h <- h_c[[1]]
        c <- h_c[[2]]
        # keep track of output (h) for every timestep
        # h has dim (batch_size, hidden_size, height, width)
        output_sequence[[t]] <- h
      }

      # stack hs for all timesteps over seq_len dimension
      # stacked_outputs has dim (batch_size, seq_len, hidden_size, height, width)
      # same as input to forward (x)
      stacked_outputs <- torch_stack(output_sequence, dim = 2)

      # pass the list of outputs (hs) to the next layer
      cur_layer_input <- stacked_outputs

      # keep track of the list of outputs for this layer
      layer_output_list[[i]] <- stacked_outputs
      # keep track of the last state for this layer
      layer_state_list[[i]] <- list(h, c)
    }

    list(layer_output_list, layer_state_list)
  }
)
Ham, Yoo-Geun, Jeong-Hwan Kim, and Jing-Jia Luo. 2019a. “Deep Learning for Multi-Year ENSO Forecasts.” Nature 573 (7775): 568–72.
———. 2019b. “Deep Learning for Multi-Year ENSO Forecasts.” Nature 573 (7775): 568–72.
