diff --git a/src/mlpack/methods/ann/gan.hpp b/src/mlpack/methods/ann/gan.hpp
index f0cf7a743..b76e6b613 100644
--- a/src/mlpack/methods/ann/gan.hpp
+++ b/src/mlpack/methods/ann/gan.hpp
@@ -30,9 +30,9 @@ using namespace mlpack::distribution;
 namespace mlpack {
 namespace ann /** artifical neural network **/ {
 template<
-typename Model = FFN<CrossEntropyError<> >,
-typename InitializationRuleType = GaussianInitialization,
-class Noise = std::normal_distribution<>>
+typename Model,
+typename InitializationRuleType,
+class Noise>
 class GAN
 {
  public:
diff --git a/src/mlpack/methods/ann/gan_impl.hpp b/src/mlpack/methods/ann/gan_impl.hpp
index 0b787443c..97f945979 100644
--- a/src/mlpack/methods/ann/gan_impl.hpp
+++ b/src/mlpack/methods/ann/gan_impl.hpp
@@ -136,7 +136,7 @@ double GAN<Model, InitializationRuleType, Noise>::Evaluate(
       std::move(boost::apply_visitor(
           outputParameterVisitor, discriminator.network.back())),
       std::move(currentTarget));
-  noise.imbue( [&]() { return noiseFunction(randGen);} );
+  noise.imbue( [&]() { return noiseFunction();} );
   generator.Forward(std::move(noise));
   arma::mat temp = boost::apply_visitor(
       outputParameterVisitor, generator.network.back());
@@ -193,7 +193,7 @@ Gradient(const arma::mat& /*parameters*/, const size_t i, arma::mat& gradient)
 
   // get the gradients of the discriminator
   discriminator.Gradient(discriminator.parameter, i, gradientDiscriminator);
-  noise.imbue( [&]() { return noiseFunction(randGen);} );
+  noise.imbue( [&]() { return noiseFunction();} );
   generator.Forward(std::move(noise));
   discriminator.predictors.col(numFunctions) = boost::apply_visitor(
       outputParameterVisitor, generator.network.back());
@@ -204,7 +204,7 @@ Gradient(const arma::mat& /*parameters*/, const size_t i, arma::mat& gradient)
 
   gradientDiscriminator += noiseGradientDiscriminator;
 
-  if (currentBatch % generatorUpdateStep == 0 && preTrainSize != 0)
+  if (currentBatch % generatorUpdateStep == 0 && preTrainSize == 0)
   {
     // Minimize log(1 - D(G(noise)))
     // pass the error from discriminator to generator
@@ -220,12 +220,13 @@ Gradient(const arma::mat& /*parameters*/, const size_t i, arma::mat& gradient)
 
     gradientGenerator = -gradientGenerator;
     gradientGenerator *= multiplier;
-
+/*
     if (counter % batchSize == 0)
     {
       Log::Info << "gradientDiscriminator = " << std::max(std::fabs(gradientDiscriminator.min()), std::fabs(gradientDiscriminator.max())) << std::endl;
       Log::Info << "gradientGenerator = " << std::max(std::fabs(gradientGenerator.min()), std::fabs(gradientGenerator.max())) << std::endl;
     }
+*/
   }
   counter++;
   if (counter >= numFunctions)
diff --git a/src/mlpack/tests/gan_test.cpp b/src/mlpack/tests/gan_test.cpp
index d969ef0fb..0453041d8 100644
--- a/src/mlpack/tests/gan_test.cpp
+++ b/src/mlpack/tests/gan_test.cpp
@@ -36,6 +36,7 @@ BOOST_AUTO_TEST_CASE(GanTest)
   size_t dOutputSize = 1;
   size_t batchSize = 100;
   size_t noiseDim = 100;
+  size_t numSamples = 10;
 
   // Load the dataset
   arma::mat trainData, dataset, noiseData;
@@ -70,9 +71,9 @@ BOOST_AUTO_TEST_CASE(GanTest)
   // Optimizer
   MiniBatchSGD optimizer(batchSize, 1e-4, 100 * trainData.n_cols, 1e-5, true);
 
-  std::normal_distribution<> noiseFunction(0.0, 1.0);
+  std::function<double()> noiseFunction = [] () { return math::RandNormal(0, 1); };
   // GAN model
-  GAN<> gan(trainData, generator, discriminator, gaussian, noiseFunction,
+  GAN<FFN<CrossEntropyError<> >, GaussianInitialization, std::function<double()> > gan(trainData, generator, discriminator, gaussian, noiseFunction,
       trainData.n_rows, batchSize, 10, 10);
 
   gan.Train(optimizer);
@@ -99,7 +100,7 @@ BOOST_AUTO_TEST_CASE(GanTest)
     generatedData.submat(dim, i * dim, 2 * dim - 1, i * dim + dim - 1) =
         samples;
   }
-  std::string output_dataset = "./output_gan_ffn"
+  std::string output_dataset = "./output_gan_ffn";
   Log::Info << "Saving output to " << output_dataset << "..." << std::endl;
   generatedData.save(output_dataset, arma::raw_ascii);
   Log::Info << "Output saved!" << std::endl;