// (C) Copyright 2017, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "networkio.h"
#include "include_gunit.h"
#include "stridemap.h"

#include <memory> // for std::make_unique, std::unique_ptr
#ifdef INCLUDE_TENSORFLOW
#  include <tensorflow/compiler/xla/array2d.h> // for xla::Array2D
#endif

namespace tesseract {

class NetworkioTest : public ::testing::Test {
protected:
  // Installs the environment's locale as the global locale before each test
  // (presumably so locale-dependent behavior is exercised).
  void SetUp() override {
    std::locale::global(std::locale(""));
  }

#ifdef INCLUDE_TENSORFLOW
  // Returns a ysize x xsize Array2D whose elements, visited row by row, hold
  // consecutive integers beginning at start.
  std::unique_ptr<xla::Array2D<int>> SetupArray(int ysize, int xsize, int start) {
    auto a = std::make_unique<xla::Array2D<int>>(ysize, xsize);
    int value = start;
    for (int y = 0; y < ysize; ++y) {
      for (int x = 0; x < xsize; ++x) {
        (*a)(y, x) = value++;
      }
    }
    return a;
  }

  // Sets up a NetworkIO with a batch of 2 "images" of known values:
  // a 3x4 image holding 0..11 and a 4x5 image holding 12..31.
  // The pixel values are offset by 128 with black 0 and contrast 128, so that
  // (as the tests check) i(t)[0] reads back as +value and i(t)[1] as -value.
  void SetupNetworkIO(NetworkIO *nio) {
    std::vector<std::unique_ptr<xla::Array2D<int>>> arrays;
    arrays.push_back(SetupArray(3, 4, 0));
    arrays.push_back(SetupArray(4, 5, 12));
    std::vector<std::pair<int, int>> h_w_sizes;
    h_w_sizes.reserve(arrays.size());
    for (const auto &array : arrays) {
      h_w_sizes.emplace_back(array->height(), array->width());
    }
    StrideMap stride_map;
    stride_map.SetStride(h_w_sizes);
    nio->ResizeToMap(true, stride_map, 2);
    // Walk every position the map's index visits and copy the matching
    // element of the batch's source array into nio.
    StrideMap::Index index(stride_map);
    do {
      const int value =
          (*arrays[index.index(FD_BATCH)])(index.index(FD_HEIGHT), index.index(FD_WIDTH));
      nio->SetPixel(index.t(), 0, 128 + value, 0.0f, 128.0f);
      nio->SetPixel(index.t(), 1, 128 - value, 0.0f, 128.0f);
    } while (index.Increment());
  }
#endif
};

// Tests that the initialization via SetPixel works and the resize correctly
// fills with zero where image sizes don't match.
TEST_F(NetworkioTest, InitWithZeroFill) {
#ifdef INCLUDE_TENSORFLOW
  NetworkIO nio;
  nio.Resize2d(true, 32, 2);
  const int width = nio.Width();
  // Pre-fill with non-test data so the checks below show that
  // SetupNetworkIO + ZeroInvalidElements leave no trace of it.
  for (int t = 0; t < width; ++t) {
    nio.SetPixel(t, 0, 0, 0.0f, 128.0f);
    nio.SetPixel(t, 1, 0, 0.0f, 128.0f);
  }
  // The initialization will wipe out all previously set values; the
  // positions the stride map does not index must then be zeroed explicitly.
  SetupNetworkIO(&nio);
  nio.ZeroInvalidElements();
  StrideMap::Index index(nio.stride_map());
  int next_t = 0;
  int pos = 0;
  do {
    const int t = index.t();
    // The indexed values just increase monotonically.
    int value = nio.i(t)[0];
    EXPECT_EQ(value, pos) << "Failure t = " << t;
    value = nio.i(t)[1];
    EXPECT_EQ(value, -pos) << "Failure t = " << t;
    // When we skip t values (padding: 40 slots hold only 32 real pixels),
    // the data is always 0.
    while (next_t < t) {
      EXPECT_EQ(nio.i(next_t)[0], 0) << "Failure t = " << next_t;
      EXPECT_EQ(nio.i(next_t)[1], 0) << "Failure t = " << next_t;
      ++next_t;
    }
    ++pos;
    ++next_t;
  } while (index.Increment());
  EXPECT_EQ(pos, 32);
  EXPECT_EQ(next_t, 40);
#else
  LOG(INFO) << "Skip test because of missing xla::Array2D";
  GTEST_SKIP();
#endif
}

// Tests that CopyWithYReversal works.
TEST_F(NetworkioTest, CopyWithYReversal) {
#ifdef INCLUDE_TENSORFLOW
  NetworkIO nio;
  SetupNetworkIO(&nio);
  NetworkIO copy;
  copy.CopyWithYReversal(nio);
  StrideMap::Index index(copy.stride_map());
  int next_t = 0;
  int pos = 0;
  // Each image's rows in reverse order: the 3x4 image (0..11), then the 4x5
  // image (12..31).
  const std::vector<int> expected_values = {8,  9,  10, 11, 4,  5,  6,  7,  0,  1,  2,
                                            3,  27, 28, 29, 30, 31, 22, 23, 24, 25, 26,
                                            17, 18, 19, 20, 21, 12, 13, 14, 15, 16};
  const int num_expected = static_cast<int>(expected_values.size());
  do {
    const int t = index.t();
    // Fail fatally, rather than read past the end of expected_values, if the
    // index visits more positions than expected.
    ASSERT_LT(pos, num_expected) << "Failure t = " << t;
    // The indexed values match the expected values.
    int value = copy.i(t)[0];
    EXPECT_EQ(value, expected_values[pos]) << "Failure t = " << t;
    value = copy.i(t)[1];
    EXPECT_EQ(value, -expected_values[pos]) << "Failure t = " << t;
    // When we skip t values, the data is always 0.
    while (next_t < t) {
      EXPECT_EQ(copy.i(next_t)[0], 0) << "Failure t = " << next_t;
      EXPECT_EQ(copy.i(next_t)[1], 0) << "Failure t = " << next_t;
      ++next_t;
    }
    ++pos;
    ++next_t;
  } while (index.Increment());
  EXPECT_EQ(pos, 32);
  EXPECT_EQ(next_t, 40);
#else
  LOG(INFO) << "Skip test because of missing xla::Array2D";
  GTEST_SKIP();
#endif
}

// Tests that CopyWithXReversal works.
TEST_F(NetworkioTest, CopyWithXReversal) {
#ifdef INCLUDE_TENSORFLOW
  NetworkIO nio;
  SetupNetworkIO(&nio);
  NetworkIO copy;
  copy.CopyWithXReversal(nio);
  StrideMap::Index index(copy.stride_map());
  int next_t = 0;
  int pos = 0;
  // Each row reversed left-to-right: the 3x4 image (0..11), then the 4x5
  // image (12..31).
  const std::vector<int> expected_values = {3,  2,  1,  0,  7,  6,  5,  4,  11, 10, 9,
                                            8,  16, 15, 14, 13, 12, 21, 20, 19, 18, 17,
                                            26, 25, 24, 23, 22, 31, 30, 29, 28, 27};
  const int num_expected = static_cast<int>(expected_values.size());
  do {
    const int t = index.t();
    // Fail fatally, rather than read past the end of expected_values, if the
    // index visits more positions than expected.
    ASSERT_LT(pos, num_expected) << "Failure t = " << t;
    // The indexed values match the expected values.
    int value = copy.i(t)[0];
    EXPECT_EQ(value, expected_values[pos]) << "Failure t = " << t;
    value = copy.i(t)[1];
    EXPECT_EQ(value, -expected_values[pos]) << "Failure t = " << t;
    // When we skip t values, the data is always 0.
    while (next_t < t) {
      EXPECT_EQ(copy.i(next_t)[0], 0) << "Failure t = " << next_t;
      EXPECT_EQ(copy.i(next_t)[1], 0) << "Failure t = " << next_t;
      ++next_t;
    }
    ++pos;
    ++next_t;
  } while (index.Increment());
  EXPECT_EQ(pos, 32);
  EXPECT_EQ(next_t, 40);
#else
  LOG(INFO) << "Skip test because of missing xla::Array2D";
  GTEST_SKIP();
#endif
}

// Tests that CopyWithXYTranspose works.
TEST_F(NetworkioTest, CopyWithXYTranspose) {
#ifdef INCLUDE_TENSORFLOW
  NetworkIO nio;
  SetupNetworkIO(&nio);
  NetworkIO copy;
  copy.CopyWithXYTranspose(nio);
  StrideMap::Index index(copy.stride_map());
  int next_t = 0;
  int pos = 0;
  // Each image read column by column: the 3x4 image (0..11) becomes 4x3,
  // the 4x5 image (12..31) becomes 5x4.
  const std::vector<int> expected_values = {0,  4,  8,  1,  5,  9,  2,  6,  10, 3,  7,
                                            11, 12, 17, 22, 27, 13, 18, 23, 28, 14, 19,
                                            24, 29, 15, 20, 25, 30, 16, 21, 26, 31};
  const int num_expected = static_cast<int>(expected_values.size());
  do {
    const int t = index.t();
    // Fail fatally, rather than read past the end of expected_values, if the
    // index visits more positions than expected.
    ASSERT_LT(pos, num_expected) << "Failure t = " << t;
    // The indexed values match the expected values.
    int value = copy.i(t)[0];
    EXPECT_EQ(value, expected_values[pos]) << "Failure t = " << t;
    value = copy.i(t)[1];
    EXPECT_EQ(value, -expected_values[pos]) << "Failure t = " << t;
    // When we skip t values, the data is always 0.
    while (next_t < t) {
      EXPECT_EQ(copy.i(next_t)[0], 0) << "Failure t = " << next_t;
      EXPECT_EQ(copy.i(next_t)[1], 0) << "Failure t = " << next_t;
      ++next_t;
    }
    ++pos;
    ++next_t;
  } while (index.Increment());
  EXPECT_EQ(pos, 32);
  EXPECT_EQ(next_t, 40);
#else
  LOG(INFO) << "Skip test because of missing xla::Array2D";
  GTEST_SKIP();
#endif
}

} // namespace tesseract
